keep eap state when refreshing

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens Langhammer
2025-05-16 22:59:16 +02:00
parent 50c50c4109
commit 8cf8f1e199
3 changed files with 16 additions and 17 deletions

View File

@ -1,6 +1,7 @@
"""Radius Provider""" """Radius Provider"""
from collections.abc import Iterable from collections.abc import Iterable
from django.db import models from django.db import models
from django.templatetags.static import static from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -41,10 +42,7 @@ class RadiusProvider(OutpostModel, Provider):
) )
certificate = models.ForeignKey( certificate = models.ForeignKey(
CertificateKeyPair, CertificateKeyPair, on_delete=models.CASCADE, default=None, null=True
on_delete=models.CASCADE,
default=None,
null=True
) )
@property @property
@ -67,7 +65,7 @@ class RadiusProvider(OutpostModel, Provider):
return RadiusProviderSerializer return RadiusProviderSerializer
def get_required_objects(self) -> Iterable[models.Model | str]: def get_required_objects(self) -> Iterable[models.Model | str]:
required_models = [self] required_models = [self, "authentik_stages_mtls.pass_outpost_certificate"]
if self.certificate is not None: if self.certificate is not None:
required_models.append(self.certificate) required_models.append(self.certificate)
return required_models return required_models

View File

@ -42,10 +42,15 @@ func (rs *RadiusServer) Refresh() error {
if len(apiProviders) < 1 { if len(apiProviders) < 1 {
return errors.New("no radius provider defined") return errors.New("no radius provider defined")
} }
providers := make([]*ProviderInstance, len(apiProviders)) providers := make(map[int32]*ProviderInstance)
for idx, provider := range apiProviders { for _, provider := range apiProviders {
existing, ok := rs.providers[provider.Pk]
state := map[string]*eap.State{}
if ok {
state = existing.eapState
}
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name) logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
providers[idx] = &ProviderInstance{ providers[provider.Pk] = &ProviderInstance{
SharedSecret: []byte(provider.GetSharedSecret()), SharedSecret: []byte(provider.GetSharedSecret()),
ClientNetworks: parseCIDRs(provider.GetClientNetworks()), ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
MFASupport: provider.GetMfaSupport(), MFASupport: provider.GetMfaSupport(),
@ -55,15 +60,10 @@ func (rs *RadiusServer) Refresh() error {
providerId: provider.Pk, providerId: provider.Pk,
s: rs, s: rs,
log: logger, log: logger,
eapState: map[string]*eap.State{}, eapState: state,
} }
} }
rs.providers = providers rs.providers = providers
rs.log.Info("Update providers") rs.log.Info("Update providers")
return nil return nil
} }
func (rs *RadiusServer) StartRadiusServer() error {
rs.log.WithField("listen", rs.s.Addr).Info("Starting radius server")
return rs.s.ListenAndServe()
}

View File

@ -35,14 +35,14 @@ type RadiusServer struct {
ac *ak.APIController ac *ak.APIController
cryptoStore *ak.CryptoStore cryptoStore *ak.CryptoStore
providers []*ProviderInstance providers map[int32]*ProviderInstance
} }
func NewServer(ac *ak.APIController) ak.Outpost { func NewServer(ac *ak.APIController) ak.Outpost {
rs := &RadiusServer{ rs := &RadiusServer{
log: log.WithField("logger", "authentik.outpost.radius"), log: log.WithField("logger", "authentik.outpost.radius"),
ac: ac, ac: ac,
providers: []*ProviderInstance{}, providers: map[int32]*ProviderInstance{},
cryptoStore: ak.NewCryptoStore(ac.Client.CryptoApi), cryptoStore: ak.NewCryptoStore(ac.Client.CryptoApi),
} }
rs.s = radius.PacketServer{ rs.s = radius.PacketServer{
@ -103,7 +103,8 @@ func (rs *RadiusServer) Start() error {
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
err := rs.StartRadiusServer() rs.log.WithField("listen", rs.s.Addr).Info("Starting radius server")
err := rs.s.ListenAndServe()
if err != nil { if err != nil {
panic(err) panic(err)
} }