Compare commits
	
		
			1 Commits
		
	
	
		
			next
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 739acf50f4 | 
| @ -35,7 +35,7 @@ type ProviderInstance struct { | |||||||
| 	cert                *tls.Certificate | 	cert                *tls.Certificate | ||||||
| 	certUUID            string | 	certUUID            string | ||||||
| 	outpostName         string | 	outpostName         string | ||||||
| 	outpostPk           int32 | 	providerPk          int32 | ||||||
| 	searchAllowedGroups []*strfmt.UUID | 	searchAllowedGroups []*strfmt.UUID | ||||||
| 	boundUsersMutex     *sync.RWMutex | 	boundUsersMutex     *sync.RWMutex | ||||||
| 	boundUsers          map[string]*flags.UserFlags | 	boundUsers          map[string]*flags.UserFlags | ||||||
|  | |||||||
| @ -22,7 +22,7 @@ import ( | |||||||
|  |  | ||||||
| func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { | func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { | ||||||
| 	for _, p := range ls.providers { | 	for _, p := range ls.providers { | ||||||
| 		if p.outpostPk == pk { | 		if p.providerPk == pk { | ||||||
| 			return p | 			return p | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error { | |||||||
| 			gidStartNumber:         provider.GetGidStartNumber(), | 			gidStartNumber:         provider.GetGidStartNumber(), | ||||||
| 			mfaSupport:             provider.GetMfaSupport(), | 			mfaSupport:             provider.GetMfaSupport(), | ||||||
| 			outpostName:            ls.ac.Outpost.Name, | 			outpostName:            ls.ac.Outpost.Name, | ||||||
| 			outpostPk:              provider.Pk, | 			providerPk:             provider.Pk, | ||||||
| 		} | 		} | ||||||
| 		if kp := provider.Certificate.Get(); kp != nil { | 		if kp := provider.Certificate.Get(); kp != nil { | ||||||
| 			err := ls.cs.AddKeypair(*kp) | 			err := ls.cs.AddKeypair(*kp) | ||||||
|  | |||||||
| @ -6,8 +6,10 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"sort" | 	"sort" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
|  | 	"goauthentik.io/internal/outpost/ldap/flags" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func parseCIDRs(raw string) []*net.IPNet { | func parseCIDRs(raw string) []*net.IPNet { | ||||||
| @ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet { | |||||||
| 	return cidrs | 	return cidrs | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance { | ||||||
|  | 	for _, p := range rs.providers { | ||||||
|  | 		if p.providerPk == pk { | ||||||
|  | 			return p | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rs *RadiusServer) getInvalidationFlow() string { | ||||||
|  | 	req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute() | ||||||
|  | 	if err != nil { | ||||||
|  | 		rs.log.WithError(err).Warning("failed to fetch brand config") | ||||||
|  | 		return "" | ||||||
|  | 	} | ||||||
|  | 	flow := req.GetFlowInvalidation() | ||||||
|  | 	return flow | ||||||
|  | } | ||||||
|  |  | ||||||
| func (rs *RadiusServer) Refresh() error { | func (rs *RadiusServer) Refresh() error { | ||||||
| 	outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute() | 	outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error { | |||||||
| 	if len(outposts.Results) < 1 { | 	if len(outposts.Results) < 1 { | ||||||
| 		return errors.New("no radius provider defined") | 		return errors.New("no radius provider defined") | ||||||
| 	} | 	} | ||||||
|  | 	invalidationFlow := rs.getInvalidationFlow() | ||||||
| 	providers := make([]*ProviderInstance, len(outposts.Results)) | 	providers := make([]*ProviderInstance, len(outposts.Results)) | ||||||
| 	for idx, provider := range outposts.Results { | 	for idx, provider := range outposts.Results { | ||||||
| 		logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name) | 		logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name) | ||||||
|  |  | ||||||
|  | 		// Get existing instance so we can transfer boundUsers | ||||||
|  | 		existing := rs.getCurrentProvider(provider.Pk) | ||||||
|  | 		usersMutex := &sync.RWMutex{} | ||||||
|  | 		users := make(map[string]*flags.UserFlags) | ||||||
|  | 		if existing != nil { | ||||||
|  | 			usersMutex = existing.boundUsersMutex | ||||||
|  | 			// Shallow copy, no need to lock | ||||||
|  | 			users = existing.boundUsers | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		providers[idx] = &ProviderInstance{ | 		providers[idx] = &ProviderInstance{ | ||||||
| 			SharedSecret:   []byte(provider.GetSharedSecret()), | 			SharedSecret:           []byte(provider.GetSharedSecret()), | ||||||
| 			ClientNetworks: parseCIDRs(provider.GetClientNetworks()), | 			ClientNetworks:         parseCIDRs(provider.GetClientNetworks()), | ||||||
| 			MFASupport:     provider.GetMfaSupport(), | 			MFASupport:             provider.GetMfaSupport(), | ||||||
| 			appSlug:        provider.ApplicationSlug, | 			appSlug:                provider.ApplicationSlug, | ||||||
| 			flowSlug:       provider.AuthFlowSlug, | 			authenticationFlowSlug: provider.AuthFlowSlug, | ||||||
| 			s:              rs, | 			invalidationFlowSlug:   invalidationFlow, | ||||||
| 			log:            logger, | 			s:                      rs, | ||||||
|  | 			log:                    logger, | ||||||
|  | 			providerPk:             provider.Pk, | ||||||
|  | 			boundUsersMutex:        usersMutex, | ||||||
|  | 			boundUsers:             users, | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	rs.providers = providers | 	rs.providers = providers | ||||||
|  | |||||||
| @ -4,15 +4,17 @@ import ( | |||||||
| 	"github.com/prometheus/client_golang/prometheus" | 	"github.com/prometheus/client_golang/prometheus" | ||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"goauthentik.io/internal/outpost/flow" | 	"goauthentik.io/internal/outpost/flow" | ||||||
|  | 	"goauthentik.io/internal/outpost/ldap/flags" | ||||||
| 	"goauthentik.io/internal/outpost/radius/metrics" | 	"goauthentik.io/internal/outpost/radius/metrics" | ||||||
| 	"layeh.com/radius" | 	"layeh.com/radius" | ||||||
| 	"layeh.com/radius/rfc2865" | 	"layeh.com/radius/rfc2865" | ||||||
|  | 	"layeh.com/radius/rfc2866" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) { | func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) { | ||||||
| 	username := rfc2865.UserName_GetString(r.Packet) | 	username := rfc2865.UserName_GetString(r.Packet) | ||||||
|  |  | ||||||
| 	fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{ | 	fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{ | ||||||
| 		"username":  username, | 		"username":  username, | ||||||
| 		"client":    r.RemoteAddr(), | 		"client":    r.RemoteAddr(), | ||||||
| 		"requestId": r.ID(), | 		"requestId": r.ID(), | ||||||
| @ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR | |||||||
| 		}).Inc() | 		}).Inc() | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	_ = w.Write(r.Response(radius.CodeAccessAccept)) | 	// Get user info to store in context | ||||||
|  | 	userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute() | ||||||
|  | 	if err != nil { | ||||||
|  | 		metrics.RequestsRejected.With(prometheus.Labels{ | ||||||
|  | 			"outpost_name": rs.ac.Outpost.Name, | ||||||
|  | 			"type":         "bind", | ||||||
|  | 			"reason":       "user_info_fail", | ||||||
|  | 		}).Inc() | ||||||
|  | 		r.Log().WithError(err).Warning("failed to get user info") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	response := r.Response(radius.CodeAccessAccept) | ||||||
|  | 	_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String()) | ||||||
|  | 	r.pi.boundUsersMutex.Lock() | ||||||
|  | 	r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{ | ||||||
|  | 		Session: fe.GetSession(), | ||||||
|  | 		UserPk:  userInfo.Original.Pk, | ||||||
|  | 	} | ||||||
|  | 	r.pi.boundUsersMutex.Unlock() | ||||||
|  | 	err = w.Write(response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.Log().WithError(err).Warning("failed to write response") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,54 @@ | |||||||
|  | package radius | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	log "github.com/sirupsen/logrus" | ||||||
|  | 	"goauthentik.io/internal/outpost/flow" | ||||||
|  | 	"goauthentik.io/internal/outpost/ldap/flags" | ||||||
|  | 	"layeh.com/radius" | ||||||
|  | 	"layeh.com/radius/rfc2866" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) { | ||||||
|  | 	session := rfc2866.AcctSessionID_GetString(r.Packet) | ||||||
|  |  | ||||||
|  | 	sendFailResponse := func() { | ||||||
|  | 		failResponse := r.Response(radius.CodeDisconnectACK) | ||||||
|  | 		err := w.Write(failResponse) | ||||||
|  | 		if err != nil { | ||||||
|  | 			r.Log().WithError(err).Warning("failed to write response") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r.pi.boundUsersMutex.Lock() | ||||||
|  | 	var f *flags.UserFlags | ||||||
|  | 	if ff, ok := r.pi.boundUsers[session]; !ok { | ||||||
|  | 		r.pi.boundUsersMutex.Unlock() | ||||||
|  | 		sendFailResponse() | ||||||
|  | 		return | ||||||
|  | 	} else { | ||||||
|  | 		f = ff | ||||||
|  | 	} | ||||||
|  | 	r.pi.boundUsersMutex.Unlock() | ||||||
|  |  | ||||||
|  | 	fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{ | ||||||
|  | 		"client":    r.RemoteAddr(), | ||||||
|  | 		"requestId": r.ID(), | ||||||
|  | 	}) | ||||||
|  | 	fe.SetSession(f.Session) | ||||||
|  | 	fe.DelegateClientIP(r.RemoteAddr()) | ||||||
|  | 	fe.Params.Add("goauthentik.io/outpost/radius", "true") | ||||||
|  | 	_, err := fe.Execute() | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.log.WithError(err).Warning("failed to logout user") | ||||||
|  | 		sendFailResponse() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	r.pi.boundUsersMutex.Lock() | ||||||
|  | 	delete(r.pi.boundUsers, session) | ||||||
|  | 	r.pi.boundUsersMutex.Unlock() | ||||||
|  | 	response := r.Response(radius.CodeDisconnectACK) | ||||||
|  | 	err = w.Write(response) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.Log().WithError(err).Warning("failed to write response") | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) | |||||||
| 	} | 	} | ||||||
| 	nr.pi = pi | 	nr.pi = pi | ||||||
|  |  | ||||||
| 	if nr.Code == radius.CodeAccessRequest { | 	switch nr.Code { | ||||||
|  | 	case radius.CodeAccessRequest: | ||||||
| 		rs.Handle_AccessRequest(w, nr) | 		rs.Handle_AccessRequest(w, nr) | ||||||
|  | 	case radius.CodeDisconnectRequest: | ||||||
|  | 		rs.Handle_DisconnectRequest(w, nr) | ||||||
|  | 	default: | ||||||
|  | 		nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -9,20 +9,25 @@ import ( | |||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"goauthentik.io/internal/config" | 	"goauthentik.io/internal/config" | ||||||
| 	"goauthentik.io/internal/outpost/ak" | 	"goauthentik.io/internal/outpost/ak" | ||||||
|  | 	"goauthentik.io/internal/outpost/ldap/flags" | ||||||
| 	"goauthentik.io/internal/outpost/radius/metrics" | 	"goauthentik.io/internal/outpost/radius/metrics" | ||||||
|  |  | ||||||
| 	"layeh.com/radius" | 	"layeh.com/radius" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ProviderInstance struct { | type ProviderInstance struct { | ||||||
| 	ClientNetworks []*net.IPNet | 	ClientNetworks  []*net.IPNet | ||||||
| 	SharedSecret   []byte | 	SharedSecret    []byte | ||||||
| 	MFASupport     bool | 	MFASupport      bool | ||||||
|  | 	boundUsersMutex *sync.RWMutex | ||||||
|  | 	boundUsers      map[string]*flags.UserFlags | ||||||
|  | 	providerPk      int32 | ||||||
|  |  | ||||||
| 	appSlug  string | 	appSlug                string | ||||||
| 	flowSlug string | 	authenticationFlowSlug string | ||||||
| 	s        *RadiusServer | 	invalidationFlowSlug   string | ||||||
| 	log      *log.Entry | 	s                      *RadiusServer | ||||||
|  | 	log                    *log.Entry | ||||||
| } | } | ||||||
|  |  | ||||||
| type RadiusServer struct { | type RadiusServer struct { | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	