Compare commits
	
		
			1 Commits
		
	
	
		
			lib/sync/d
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 739acf50f4 | 
| @ -35,7 +35,7 @@ type ProviderInstance struct { | ||||
| 	cert                *tls.Certificate | ||||
| 	certUUID            string | ||||
| 	outpostName         string | ||||
| 	outpostPk           int32 | ||||
| 	providerPk          int32 | ||||
| 	searchAllowedGroups []*strfmt.UUID | ||||
| 	boundUsersMutex     *sync.RWMutex | ||||
| 	boundUsers          map[string]*flags.UserFlags | ||||
|  | ||||
| @ -22,7 +22,7 @@ import ( | ||||
|  | ||||
| func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { | ||||
| 	for _, p := range ls.providers { | ||||
| 		if p.outpostPk == pk { | ||||
| 		if p.providerPk == pk { | ||||
| 			return p | ||||
| 		} | ||||
| 	} | ||||
| @ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error { | ||||
| 			gidStartNumber:         provider.GetGidStartNumber(), | ||||
| 			mfaSupport:             provider.GetMfaSupport(), | ||||
| 			outpostName:            ls.ac.Outpost.Name, | ||||
| 			outpostPk:              provider.Pk, | ||||
| 			providerPk:             provider.Pk, | ||||
| 		} | ||||
| 		if kp := provider.Certificate.Get(); kp != nil { | ||||
| 			err := ls.cs.AddKeypair(*kp) | ||||
|  | ||||
| @ -6,8 +6,10 @@ import ( | ||||
| 	"net" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"goauthentik.io/internal/outpost/ldap/flags" | ||||
| ) | ||||
|  | ||||
| func parseCIDRs(raw string) []*net.IPNet { | ||||
| @ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet { | ||||
| 	return cidrs | ||||
| } | ||||
|  | ||||
| func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance { | ||||
| 	for _, p := range rs.providers { | ||||
| 		if p.providerPk == pk { | ||||
| 			return p | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (rs *RadiusServer) getInvalidationFlow() string { | ||||
| 	req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute() | ||||
| 	if err != nil { | ||||
| 		rs.log.WithError(err).Warning("failed to fetch brand config") | ||||
| 		return "" | ||||
| 	} | ||||
| 	flow := req.GetFlowInvalidation() | ||||
| 	return flow | ||||
| } | ||||
|  | ||||
| func (rs *RadiusServer) Refresh() error { | ||||
| 	outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute() | ||||
| 	if err != nil { | ||||
| @ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error { | ||||
| 	if len(outposts.Results) < 1 { | ||||
| 		return errors.New("no radius provider defined") | ||||
| 	} | ||||
| 	invalidationFlow := rs.getInvalidationFlow() | ||||
| 	providers := make([]*ProviderInstance, len(outposts.Results)) | ||||
| 	for idx, provider := range outposts.Results { | ||||
| 		logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name) | ||||
|  | ||||
| 		// Get existing instance so we can transfer boundUsers | ||||
| 		existing := rs.getCurrentProvider(provider.Pk) | ||||
| 		usersMutex := &sync.RWMutex{} | ||||
| 		users := make(map[string]*flags.UserFlags) | ||||
| 		if existing != nil { | ||||
| 			usersMutex = existing.boundUsersMutex | ||||
| 			// Shallow copy, no need to lock | ||||
| 			users = existing.boundUsers | ||||
| 		} | ||||
|  | ||||
| 		providers[idx] = &ProviderInstance{ | ||||
| 			SharedSecret:   []byte(provider.GetSharedSecret()), | ||||
| 			ClientNetworks: parseCIDRs(provider.GetClientNetworks()), | ||||
| 			MFASupport:     provider.GetMfaSupport(), | ||||
| 			appSlug:        provider.ApplicationSlug, | ||||
| 			flowSlug:       provider.AuthFlowSlug, | ||||
| 			s:              rs, | ||||
| 			log:            logger, | ||||
| 			SharedSecret:           []byte(provider.GetSharedSecret()), | ||||
| 			ClientNetworks:         parseCIDRs(provider.GetClientNetworks()), | ||||
| 			MFASupport:             provider.GetMfaSupport(), | ||||
| 			appSlug:                provider.ApplicationSlug, | ||||
| 			authenticationFlowSlug: provider.AuthFlowSlug, | ||||
| 			invalidationFlowSlug:   invalidationFlow, | ||||
| 			s:                      rs, | ||||
| 			log:                    logger, | ||||
| 			providerPk:             provider.Pk, | ||||
| 			boundUsersMutex:        usersMutex, | ||||
| 			boundUsers:             users, | ||||
| 		} | ||||
| 	} | ||||
| 	rs.providers = providers | ||||
|  | ||||
| @ -4,15 +4,17 @@ import ( | ||||
| 	"github.com/prometheus/client_golang/prometheus" | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"goauthentik.io/internal/outpost/flow" | ||||
| 	"goauthentik.io/internal/outpost/ldap/flags" | ||||
| 	"goauthentik.io/internal/outpost/radius/metrics" | ||||
| 	"layeh.com/radius" | ||||
| 	"layeh.com/radius/rfc2865" | ||||
| 	"layeh.com/radius/rfc2866" | ||||
| ) | ||||
|  | ||||
| func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) { | ||||
| 	username := rfc2865.UserName_GetString(r.Packet) | ||||
|  | ||||
| 	fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{ | ||||
| 	fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{ | ||||
| 		"username":  username, | ||||
| 		"client":    r.RemoteAddr(), | ||||
| 		"requestId": r.ID(), | ||||
| @ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR | ||||
| 		}).Inc() | ||||
| 		return | ||||
| 	} | ||||
| 	_ = w.Write(r.Response(radius.CodeAccessAccept)) | ||||
| 	// Get user info to store in context | ||||
| 	userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute() | ||||
| 	if err != nil { | ||||
| 		metrics.RequestsRejected.With(prometheus.Labels{ | ||||
| 			"outpost_name": rs.ac.Outpost.Name, | ||||
| 			"type":         "bind", | ||||
| 			"reason":       "user_info_fail", | ||||
| 		}).Inc() | ||||
| 		r.Log().WithError(err).Warning("failed to get user info") | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	response := r.Response(radius.CodeAccessAccept) | ||||
| 	_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String()) | ||||
| 	r.pi.boundUsersMutex.Lock() | ||||
| 	r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{ | ||||
| 		Session: fe.GetSession(), | ||||
| 		UserPk:  userInfo.Original.Pk, | ||||
| 	} | ||||
| 	r.pi.boundUsersMutex.Unlock() | ||||
| 	err = w.Write(response) | ||||
| 	if err != nil { | ||||
| 		r.Log().WithError(err).Warning("failed to write response") | ||||
| 	} | ||||
| } | ||||
|  | ||||
							
								
								
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								internal/outpost/radius/handle_disconnect_request.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,54 @@ | ||||
| package radius | ||||
|  | ||||
| import ( | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"goauthentik.io/internal/outpost/flow" | ||||
| 	"goauthentik.io/internal/outpost/ldap/flags" | ||||
| 	"layeh.com/radius" | ||||
| 	"layeh.com/radius/rfc2866" | ||||
| ) | ||||
|  | ||||
| func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) { | ||||
| 	session := rfc2866.AcctSessionID_GetString(r.Packet) | ||||
|  | ||||
| 	sendFailResponse := func() { | ||||
| 		failResponse := r.Response(radius.CodeDisconnectACK) | ||||
| 		err := w.Write(failResponse) | ||||
| 		if err != nil { | ||||
| 			r.Log().WithError(err).Warning("failed to write response") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	r.pi.boundUsersMutex.Lock() | ||||
| 	var f *flags.UserFlags | ||||
| 	if ff, ok := r.pi.boundUsers[session]; !ok { | ||||
| 		r.pi.boundUsersMutex.Unlock() | ||||
| 		sendFailResponse() | ||||
| 		return | ||||
| 	} else { | ||||
| 		f = ff | ||||
| 	} | ||||
| 	r.pi.boundUsersMutex.Unlock() | ||||
|  | ||||
| 	fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{ | ||||
| 		"client":    r.RemoteAddr(), | ||||
| 		"requestId": r.ID(), | ||||
| 	}) | ||||
| 	fe.SetSession(f.Session) | ||||
| 	fe.DelegateClientIP(r.RemoteAddr()) | ||||
| 	fe.Params.Add("goauthentik.io/outpost/radius", "true") | ||||
| 	_, err := fe.Execute() | ||||
| 	if err != nil { | ||||
| 		r.log.WithError(err).Warning("failed to logout user") | ||||
| 		sendFailResponse() | ||||
| 		return | ||||
| 	} | ||||
| 	r.pi.boundUsersMutex.Lock() | ||||
| 	delete(r.pi.boundUsers, session) | ||||
| 	r.pi.boundUsersMutex.Unlock() | ||||
| 	response := r.Response(radius.CodeDisconnectACK) | ||||
| 	err = w.Write(response) | ||||
| 	if err != nil { | ||||
| 		r.Log().WithError(err).Warning("failed to write response") | ||||
| 	} | ||||
| } | ||||
| @ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request) | ||||
| 	} | ||||
| 	nr.pi = pi | ||||
|  | ||||
| 	if nr.Code == radius.CodeAccessRequest { | ||||
| 	switch nr.Code { | ||||
| 	case radius.CodeAccessRequest: | ||||
| 		rs.Handle_AccessRequest(w, nr) | ||||
| 	case radius.CodeDisconnectRequest: | ||||
| 		rs.Handle_DisconnectRequest(w, nr) | ||||
| 	default: | ||||
| 		nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -9,20 +9,25 @@ import ( | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"goauthentik.io/internal/config" | ||||
| 	"goauthentik.io/internal/outpost/ak" | ||||
| 	"goauthentik.io/internal/outpost/ldap/flags" | ||||
| 	"goauthentik.io/internal/outpost/radius/metrics" | ||||
|  | ||||
| 	"layeh.com/radius" | ||||
| ) | ||||
|  | ||||
| type ProviderInstance struct { | ||||
| 	ClientNetworks []*net.IPNet | ||||
| 	SharedSecret   []byte | ||||
| 	MFASupport     bool | ||||
| 	ClientNetworks  []*net.IPNet | ||||
| 	SharedSecret    []byte | ||||
| 	MFASupport      bool | ||||
| 	boundUsersMutex *sync.RWMutex | ||||
| 	boundUsers      map[string]*flags.UserFlags | ||||
| 	providerPk      int32 | ||||
|  | ||||
| 	appSlug  string | ||||
| 	flowSlug string | ||||
| 	s        *RadiusServer | ||||
| 	log      *log.Entry | ||||
| 	appSlug                string | ||||
| 	authenticationFlowSlug string | ||||
| 	invalidationFlowSlug   string | ||||
| 	s                      *RadiusServer | ||||
| 	log                    *log.Entry | ||||
| } | ||||
|  | ||||
| type RadiusServer struct { | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	