Compare commits
	
		
			1 Commits
		
	
	
		
			version/20
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 739acf50f4 | 
@ -35,7 +35,7 @@ type ProviderInstance struct {
 | 
			
		||||
	cert                *tls.Certificate
 | 
			
		||||
	certUUID            string
 | 
			
		||||
	outpostName         string
 | 
			
		||||
	outpostPk           int32
 | 
			
		||||
	providerPk          int32
 | 
			
		||||
	searchAllowedGroups []*strfmt.UUID
 | 
			
		||||
	boundUsersMutex     *sync.RWMutex
 | 
			
		||||
	boundUsers          map[string]*flags.UserFlags
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ import (
 | 
			
		||||
 | 
			
		||||
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
 | 
			
		||||
	for _, p := range ls.providers {
 | 
			
		||||
		if p.outpostPk == pk {
 | 
			
		||||
		if p.providerPk == pk {
 | 
			
		||||
			return p
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error {
 | 
			
		||||
			gidStartNumber:         provider.GetGidStartNumber(),
 | 
			
		||||
			mfaSupport:             provider.GetMfaSupport(),
 | 
			
		||||
			outpostName:            ls.ac.Outpost.Name,
 | 
			
		||||
			outpostPk:              provider.Pk,
 | 
			
		||||
			providerPk:             provider.Pk,
 | 
			
		||||
		}
 | 
			
		||||
		if kp := provider.Certificate.Get(); kp != nil {
 | 
			
		||||
			err := ls.cs.AddKeypair(*kp)
 | 
			
		||||
 | 
			
		||||
@ -6,8 +6,10 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/flags"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func parseCIDRs(raw string) []*net.IPNet {
 | 
			
		||||
@ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet {
 | 
			
		||||
	return cidrs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
 | 
			
		||||
	for _, p := range rs.providers {
 | 
			
		||||
		if p.providerPk == pk {
 | 
			
		||||
			return p
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rs *RadiusServer) getInvalidationFlow() string {
 | 
			
		||||
	req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		rs.log.WithError(err).Warning("failed to fetch brand config")
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	flow := req.GetFlowInvalidation()
 | 
			
		||||
	return flow
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rs *RadiusServer) Refresh() error {
 | 
			
		||||
	outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error {
 | 
			
		||||
	if len(outposts.Results) < 1 {
 | 
			
		||||
		return errors.New("no radius provider defined")
 | 
			
		||||
	}
 | 
			
		||||
	invalidationFlow := rs.getInvalidationFlow()
 | 
			
		||||
	providers := make([]*ProviderInstance, len(outposts.Results))
 | 
			
		||||
	for idx, provider := range outposts.Results {
 | 
			
		||||
		logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
 | 
			
		||||
 | 
			
		||||
		// Get existing instance so we can transfer boundUsers
 | 
			
		||||
		existing := rs.getCurrentProvider(provider.Pk)
 | 
			
		||||
		usersMutex := &sync.RWMutex{}
 | 
			
		||||
		users := make(map[string]*flags.UserFlags)
 | 
			
		||||
		if existing != nil {
 | 
			
		||||
			usersMutex = existing.boundUsersMutex
 | 
			
		||||
			// Shallow copy, no need to lock
 | 
			
		||||
			users = existing.boundUsers
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		providers[idx] = &ProviderInstance{
 | 
			
		||||
			SharedSecret:           []byte(provider.GetSharedSecret()),
 | 
			
		||||
			ClientNetworks:         parseCIDRs(provider.GetClientNetworks()),
 | 
			
		||||
			MFASupport:             provider.GetMfaSupport(),
 | 
			
		||||
			appSlug:                provider.ApplicationSlug,
 | 
			
		||||
			flowSlug:       provider.AuthFlowSlug,
 | 
			
		||||
			authenticationFlowSlug: provider.AuthFlowSlug,
 | 
			
		||||
			invalidationFlowSlug:   invalidationFlow,
 | 
			
		||||
			s:                      rs,
 | 
			
		||||
			log:                    logger,
 | 
			
		||||
			providerPk:             provider.Pk,
 | 
			
		||||
			boundUsersMutex:        usersMutex,
 | 
			
		||||
			boundUsers:             users,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	rs.providers = providers
 | 
			
		||||
 | 
			
		||||
@ -4,15 +4,17 @@ import (
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/internal/outpost/flow"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/flags"
 | 
			
		||||
	"goauthentik.io/internal/outpost/radius/metrics"
 | 
			
		||||
	"layeh.com/radius"
 | 
			
		||||
	"layeh.com/radius/rfc2865"
 | 
			
		||||
	"layeh.com/radius/rfc2866"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
 | 
			
		||||
	username := rfc2865.UserName_GetString(r.Packet)
 | 
			
		||||
 | 
			
		||||
	fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
 | 
			
		||||
	fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
 | 
			
		||||
		"username":  username,
 | 
			
		||||
		"client":    r.RemoteAddr(),
 | 
			
		||||
		"requestId": r.ID(),
 | 
			
		||||
@ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR
 | 
			
		||||
		}).Inc()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	_ = w.Write(r.Response(radius.CodeAccessAccept))
 | 
			
		||||
	// Get user info to store in context
 | 
			
		||||
	userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		metrics.RequestsRejected.With(prometheus.Labels{
 | 
			
		||||
			"outpost_name": rs.ac.Outpost.Name,
 | 
			
		||||
			"type":         "bind",
 | 
			
		||||
			"reason":       "user_info_fail",
 | 
			
		||||
		}).Inc()
 | 
			
		||||
		r.Log().WithError(err).Warning("failed to get user info")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	response := r.Response(radius.CodeAccessAccept)
 | 
			
		||||
	_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String())
 | 
			
		||||
	r.pi.boundUsersMutex.Lock()
 | 
			
		||||
	r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{
 | 
			
		||||
		Session: fe.GetSession(),
 | 
			
		||||
		UserPk:  userInfo.Original.Pk,
 | 
			
		||||
	}
 | 
			
		||||
	r.pi.boundUsersMutex.Unlock()
 | 
			
		||||
	err = w.Write(response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		r.Log().WithError(err).Warning("failed to write response")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
			
		||||
package radius
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/internal/outpost/flow"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/flags"
 | 
			
		||||
	"layeh.com/radius"
 | 
			
		||||
	"layeh.com/radius/rfc2866"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) {
 | 
			
		||||
	session := rfc2866.AcctSessionID_GetString(r.Packet)
 | 
			
		||||
 | 
			
		||||
	sendFailResponse := func() {
 | 
			
		||||
		failResponse := r.Response(radius.CodeDisconnectACK)
 | 
			
		||||
		err := w.Write(failResponse)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			r.Log().WithError(err).Warning("failed to write response")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.pi.boundUsersMutex.Lock()
 | 
			
		||||
	var f *flags.UserFlags
 | 
			
		||||
	if ff, ok := r.pi.boundUsers[session]; !ok {
 | 
			
		||||
		r.pi.boundUsersMutex.Unlock()
 | 
			
		||||
		sendFailResponse()
 | 
			
		||||
		return
 | 
			
		||||
	} else {
 | 
			
		||||
		f = ff
 | 
			
		||||
	}
 | 
			
		||||
	r.pi.boundUsersMutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{
 | 
			
		||||
		"client":    r.RemoteAddr(),
 | 
			
		||||
		"requestId": r.ID(),
 | 
			
		||||
	})
 | 
			
		||||
	fe.SetSession(f.Session)
 | 
			
		||||
	fe.DelegateClientIP(r.RemoteAddr())
 | 
			
		||||
	fe.Params.Add("goauthentik.io/outpost/radius", "true")
 | 
			
		||||
	_, err := fe.Execute()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		r.log.WithError(err).Warning("failed to logout user")
 | 
			
		||||
		sendFailResponse()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.pi.boundUsersMutex.Lock()
 | 
			
		||||
	delete(r.pi.boundUsers, session)
 | 
			
		||||
	r.pi.boundUsersMutex.Unlock()
 | 
			
		||||
	response := r.Response(radius.CodeDisconnectACK)
 | 
			
		||||
	err = w.Write(response)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		r.Log().WithError(err).Warning("failed to write response")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
 | 
			
		||||
	}
 | 
			
		||||
	nr.pi = pi
 | 
			
		||||
 | 
			
		||||
	if nr.Code == radius.CodeAccessRequest {
 | 
			
		||||
	switch nr.Code {
 | 
			
		||||
	case radius.CodeAccessRequest:
 | 
			
		||||
		rs.Handle_AccessRequest(w, nr)
 | 
			
		||||
	case radius.CodeDisconnectRequest:
 | 
			
		||||
		rs.Handle_DisconnectRequest(w, nr)
 | 
			
		||||
	default:
 | 
			
		||||
		nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,7 @@ import (
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/internal/config"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ak"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/flags"
 | 
			
		||||
	"goauthentik.io/internal/outpost/radius/metrics"
 | 
			
		||||
 | 
			
		||||
	"layeh.com/radius"
 | 
			
		||||
@ -18,9 +19,13 @@ type ProviderInstance struct {
 | 
			
		||||
	ClientNetworks  []*net.IPNet
 | 
			
		||||
	SharedSecret    []byte
 | 
			
		||||
	MFASupport      bool
 | 
			
		||||
	boundUsersMutex *sync.RWMutex
 | 
			
		||||
	boundUsers      map[string]*flags.UserFlags
 | 
			
		||||
	providerPk      int32
 | 
			
		||||
 | 
			
		||||
	appSlug                string
 | 
			
		||||
	flowSlug string
 | 
			
		||||
	authenticationFlowSlug string
 | 
			
		||||
	invalidationFlowSlug   string
 | 
			
		||||
	s                      *RadiusServer
 | 
			
		||||
	log                    *log.Entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user