Compare commits
	
		
			1 Commits
		
	
	
		
			consistent
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 739acf50f4 | 
@ -35,7 +35,7 @@ type ProviderInstance struct {
 | 
				
			|||||||
	cert                *tls.Certificate
 | 
						cert                *tls.Certificate
 | 
				
			||||||
	certUUID            string
 | 
						certUUID            string
 | 
				
			||||||
	outpostName         string
 | 
						outpostName         string
 | 
				
			||||||
	outpostPk           int32
 | 
						providerPk          int32
 | 
				
			||||||
	searchAllowedGroups []*strfmt.UUID
 | 
						searchAllowedGroups []*strfmt.UUID
 | 
				
			||||||
	boundUsersMutex     *sync.RWMutex
 | 
						boundUsersMutex     *sync.RWMutex
 | 
				
			||||||
	boundUsers          map[string]*flags.UserFlags
 | 
						boundUsers          map[string]*flags.UserFlags
 | 
				
			||||||
 | 
				
			|||||||
@ -22,7 +22,7 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
 | 
					func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
 | 
				
			||||||
	for _, p := range ls.providers {
 | 
						for _, p := range ls.providers {
 | 
				
			||||||
		if p.outpostPk == pk {
 | 
							if p.providerPk == pk {
 | 
				
			||||||
			return p
 | 
								return p
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error {
 | 
				
			|||||||
			gidStartNumber:         provider.GetGidStartNumber(),
 | 
								gidStartNumber:         provider.GetGidStartNumber(),
 | 
				
			||||||
			mfaSupport:             provider.GetMfaSupport(),
 | 
								mfaSupport:             provider.GetMfaSupport(),
 | 
				
			||||||
			outpostName:            ls.ac.Outpost.Name,
 | 
								outpostName:            ls.ac.Outpost.Name,
 | 
				
			||||||
			outpostPk:              provider.Pk,
 | 
								providerPk:             provider.Pk,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if kp := provider.Certificate.Get(); kp != nil {
 | 
							if kp := provider.Certificate.Get(); kp != nil {
 | 
				
			||||||
			err := ls.cs.AddKeypair(*kp)
 | 
								err := ls.cs.AddKeypair(*kp)
 | 
				
			||||||
 | 
				
			|||||||
@ -6,8 +6,10 @@ import (
 | 
				
			|||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"goauthentik.io/internal/outpost/ldap/flags"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func parseCIDRs(raw string) []*net.IPNet {
 | 
					func parseCIDRs(raw string) []*net.IPNet {
 | 
				
			||||||
@ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet {
 | 
				
			|||||||
	return cidrs
 | 
						return cidrs
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
 | 
				
			||||||
 | 
						for _, p := range rs.providers {
 | 
				
			||||||
 | 
							if p.providerPk == pk {
 | 
				
			||||||
 | 
								return p
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rs *RadiusServer) getInvalidationFlow() string {
 | 
				
			||||||
 | 
						req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							rs.log.WithError(err).Warning("failed to fetch brand config")
 | 
				
			||||||
 | 
							return ""
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						flow := req.GetFlowInvalidation()
 | 
				
			||||||
 | 
						return flow
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (rs *RadiusServer) Refresh() error {
 | 
					func (rs *RadiusServer) Refresh() error {
 | 
				
			||||||
	outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
 | 
						outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error {
 | 
				
			|||||||
	if len(outposts.Results) < 1 {
 | 
						if len(outposts.Results) < 1 {
 | 
				
			||||||
		return errors.New("no radius provider defined")
 | 
							return errors.New("no radius provider defined")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						invalidationFlow := rs.getInvalidationFlow()
 | 
				
			||||||
	providers := make([]*ProviderInstance, len(outposts.Results))
 | 
						providers := make([]*ProviderInstance, len(outposts.Results))
 | 
				
			||||||
	for idx, provider := range outposts.Results {
 | 
						for idx, provider := range outposts.Results {
 | 
				
			||||||
		logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
 | 
							logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Get existing instance so we can transfer boundUsers
 | 
				
			||||||
 | 
							existing := rs.getCurrentProvider(provider.Pk)
 | 
				
			||||||
 | 
							usersMutex := &sync.RWMutex{}
 | 
				
			||||||
 | 
							users := make(map[string]*flags.UserFlags)
 | 
				
			||||||
 | 
							if existing != nil {
 | 
				
			||||||
 | 
								usersMutex = existing.boundUsersMutex
 | 
				
			||||||
 | 
								// Shallow copy, no need to lock
 | 
				
			||||||
 | 
								users = existing.boundUsers
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		providers[idx] = &ProviderInstance{
 | 
							providers[idx] = &ProviderInstance{
 | 
				
			||||||
			SharedSecret:           []byte(provider.GetSharedSecret()),
 | 
								SharedSecret:           []byte(provider.GetSharedSecret()),
 | 
				
			||||||
			ClientNetworks:         parseCIDRs(provider.GetClientNetworks()),
 | 
								ClientNetworks:         parseCIDRs(provider.GetClientNetworks()),
 | 
				
			||||||
			MFASupport:             provider.GetMfaSupport(),
 | 
								MFASupport:             provider.GetMfaSupport(),
 | 
				
			||||||
			appSlug:                provider.ApplicationSlug,
 | 
								appSlug:                provider.ApplicationSlug,
 | 
				
			||||||
			flowSlug:       provider.AuthFlowSlug,
 | 
								authenticationFlowSlug: provider.AuthFlowSlug,
 | 
				
			||||||
 | 
								invalidationFlowSlug:   invalidationFlow,
 | 
				
			||||||
			s:                      rs,
 | 
								s:                      rs,
 | 
				
			||||||
			log:                    logger,
 | 
								log:                    logger,
 | 
				
			||||||
 | 
								providerPk:             provider.Pk,
 | 
				
			||||||
 | 
								boundUsersMutex:        usersMutex,
 | 
				
			||||||
 | 
								boundUsers:             users,
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	rs.providers = providers
 | 
						rs.providers = providers
 | 
				
			||||||
 | 
				
			|||||||
@ -4,15 +4,17 @@ import (
 | 
				
			|||||||
	"github.com/prometheus/client_golang/prometheus"
 | 
						"github.com/prometheus/client_golang/prometheus"
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
	"goauthentik.io/internal/outpost/flow"
 | 
						"goauthentik.io/internal/outpost/flow"
 | 
				
			||||||
 | 
						"goauthentik.io/internal/outpost/ldap/flags"
 | 
				
			||||||
	"goauthentik.io/internal/outpost/radius/metrics"
 | 
						"goauthentik.io/internal/outpost/radius/metrics"
 | 
				
			||||||
	"layeh.com/radius"
 | 
						"layeh.com/radius"
 | 
				
			||||||
	"layeh.com/radius/rfc2865"
 | 
						"layeh.com/radius/rfc2865"
 | 
				
			||||||
 | 
						"layeh.com/radius/rfc2866"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
 | 
					func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
 | 
				
			||||||
	username := rfc2865.UserName_GetString(r.Packet)
 | 
						username := rfc2865.UserName_GetString(r.Packet)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
 | 
						fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
 | 
				
			||||||
		"username":  username,
 | 
							"username":  username,
 | 
				
			||||||
		"client":    r.RemoteAddr(),
 | 
							"client":    r.RemoteAddr(),
 | 
				
			||||||
		"requestId": r.ID(),
 | 
							"requestId": r.ID(),
 | 
				
			||||||
@ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR
 | 
				
			|||||||
		}).Inc()
 | 
							}).Inc()
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	_ = w.Write(r.Response(radius.CodeAccessAccept))
 | 
						// Get user info to store in context
 | 
				
			||||||
 | 
						userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							metrics.RequestsRejected.With(prometheus.Labels{
 | 
				
			||||||
 | 
								"outpost_name": rs.ac.Outpost.Name,
 | 
				
			||||||
 | 
								"type":         "bind",
 | 
				
			||||||
 | 
								"reason":       "user_info_fail",
 | 
				
			||||||
 | 
							}).Inc()
 | 
				
			||||||
 | 
							r.Log().WithError(err).Warning("failed to get user info")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						response := r.Response(radius.CodeAccessAccept)
 | 
				
			||||||
 | 
						_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String())
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Lock()
 | 
				
			||||||
 | 
						r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{
 | 
				
			||||||
 | 
							Session: fe.GetSession(),
 | 
				
			||||||
 | 
							UserPk:  userInfo.Original.Pk,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Unlock()
 | 
				
			||||||
 | 
						err = w.Write(response)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							r.Log().WithError(err).Warning("failed to write response")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
				
			|||||||
 | 
					package radius
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"goauthentik.io/internal/outpost/flow"
 | 
				
			||||||
 | 
						"goauthentik.io/internal/outpost/ldap/flags"
 | 
				
			||||||
 | 
						"layeh.com/radius"
 | 
				
			||||||
 | 
						"layeh.com/radius/rfc2866"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) {
 | 
				
			||||||
 | 
						session := rfc2866.AcctSessionID_GetString(r.Packet)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sendFailResponse := func() {
 | 
				
			||||||
 | 
							failResponse := r.Response(radius.CodeDisconnectACK)
 | 
				
			||||||
 | 
							err := w.Write(failResponse)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								r.Log().WithError(err).Warning("failed to write response")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Lock()
 | 
				
			||||||
 | 
						var f *flags.UserFlags
 | 
				
			||||||
 | 
						if ff, ok := r.pi.boundUsers[session]; !ok {
 | 
				
			||||||
 | 
							r.pi.boundUsersMutex.Unlock()
 | 
				
			||||||
 | 
							sendFailResponse()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							f = ff
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{
 | 
				
			||||||
 | 
							"client":    r.RemoteAddr(),
 | 
				
			||||||
 | 
							"requestId": r.ID(),
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						fe.SetSession(f.Session)
 | 
				
			||||||
 | 
						fe.DelegateClientIP(r.RemoteAddr())
 | 
				
			||||||
 | 
						fe.Params.Add("goauthentik.io/outpost/radius", "true")
 | 
				
			||||||
 | 
						_, err := fe.Execute()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							r.log.WithError(err).Warning("failed to logout user")
 | 
				
			||||||
 | 
							sendFailResponse()
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Lock()
 | 
				
			||||||
 | 
						delete(r.pi.boundUsers, session)
 | 
				
			||||||
 | 
						r.pi.boundUsersMutex.Unlock()
 | 
				
			||||||
 | 
						response := r.Response(radius.CodeDisconnectACK)
 | 
				
			||||||
 | 
						err = w.Write(response)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							r.Log().WithError(err).Warning("failed to write response")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	nr.pi = pi
 | 
						nr.pi = pi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if nr.Code == radius.CodeAccessRequest {
 | 
						switch nr.Code {
 | 
				
			||||||
 | 
						case radius.CodeAccessRequest:
 | 
				
			||||||
		rs.Handle_AccessRequest(w, nr)
 | 
							rs.Handle_AccessRequest(w, nr)
 | 
				
			||||||
 | 
						case radius.CodeDisconnectRequest:
 | 
				
			||||||
 | 
							rs.Handle_DisconnectRequest(w, nr)
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -9,6 +9,7 @@ import (
 | 
				
			|||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
	"goauthentik.io/internal/config"
 | 
						"goauthentik.io/internal/config"
 | 
				
			||||||
	"goauthentik.io/internal/outpost/ak"
 | 
						"goauthentik.io/internal/outpost/ak"
 | 
				
			||||||
 | 
						"goauthentik.io/internal/outpost/ldap/flags"
 | 
				
			||||||
	"goauthentik.io/internal/outpost/radius/metrics"
 | 
						"goauthentik.io/internal/outpost/radius/metrics"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"layeh.com/radius"
 | 
						"layeh.com/radius"
 | 
				
			||||||
@ -18,9 +19,13 @@ type ProviderInstance struct {
 | 
				
			|||||||
	ClientNetworks  []*net.IPNet
 | 
						ClientNetworks  []*net.IPNet
 | 
				
			||||||
	SharedSecret    []byte
 | 
						SharedSecret    []byte
 | 
				
			||||||
	MFASupport      bool
 | 
						MFASupport      bool
 | 
				
			||||||
 | 
						boundUsersMutex *sync.RWMutex
 | 
				
			||||||
 | 
						boundUsers      map[string]*flags.UserFlags
 | 
				
			||||||
 | 
						providerPk      int32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	appSlug                string
 | 
						appSlug                string
 | 
				
			||||||
	flowSlug string
 | 
						authenticationFlowSlug string
 | 
				
			||||||
 | 
						invalidationFlowSlug   string
 | 
				
			||||||
	s                      *RadiusServer
 | 
						s                      *RadiusServer
 | 
				
			||||||
	log                    *log.Entry
 | 
						log                    *log.Entry
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user