keep eap state when refreshing
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
"""Radius Provider"""
|
"""Radius Provider"""
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.templatetags.static import static
|
from django.templatetags.static import static
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
@ -41,10 +42,7 @@ class RadiusProvider(OutpostModel, Provider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
certificate = models.ForeignKey(
|
certificate = models.ForeignKey(
|
||||||
CertificateKeyPair,
|
CertificateKeyPair, on_delete=models.CASCADE, default=None, null=True
|
||||||
on_delete=models.CASCADE,
|
|
||||||
default=None,
|
|
||||||
null=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -67,7 +65,7 @@ class RadiusProvider(OutpostModel, Provider):
|
|||||||
return RadiusProviderSerializer
|
return RadiusProviderSerializer
|
||||||
|
|
||||||
def get_required_objects(self) -> Iterable[models.Model | str]:
|
def get_required_objects(self) -> Iterable[models.Model | str]:
|
||||||
required_models = [self]
|
required_models = [self, "authentik_stages_mtls.pass_outpost_certificate"]
|
||||||
if self.certificate is not None:
|
if self.certificate is not None:
|
||||||
required_models.append(self.certificate)
|
required_models.append(self.certificate)
|
||||||
return required_models
|
return required_models
|
||||||
|
@ -42,10 +42,15 @@ func (rs *RadiusServer) Refresh() error {
|
|||||||
if len(apiProviders) < 1 {
|
if len(apiProviders) < 1 {
|
||||||
return errors.New("no radius provider defined")
|
return errors.New("no radius provider defined")
|
||||||
}
|
}
|
||||||
providers := make([]*ProviderInstance, len(apiProviders))
|
providers := make(map[int32]*ProviderInstance)
|
||||||
for idx, provider := range apiProviders {
|
for _, provider := range apiProviders {
|
||||||
|
existing, ok := rs.providers[provider.Pk]
|
||||||
|
state := map[string]*eap.State{}
|
||||||
|
if ok {
|
||||||
|
state = existing.eapState
|
||||||
|
}
|
||||||
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
||||||
providers[idx] = &ProviderInstance{
|
providers[provider.Pk] = &ProviderInstance{
|
||||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||||
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
||||||
MFASupport: provider.GetMfaSupport(),
|
MFASupport: provider.GetMfaSupport(),
|
||||||
@ -55,15 +60,10 @@ func (rs *RadiusServer) Refresh() error {
|
|||||||
providerId: provider.Pk,
|
providerId: provider.Pk,
|
||||||
s: rs,
|
s: rs,
|
||||||
log: logger,
|
log: logger,
|
||||||
eapState: map[string]*eap.State{},
|
eapState: state,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rs.providers = providers
|
rs.providers = providers
|
||||||
rs.log.Info("Update providers")
|
rs.log.Info("Update providers")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RadiusServer) StartRadiusServer() error {
|
|
||||||
rs.log.WithField("listen", rs.s.Addr).Info("Starting radius server")
|
|
||||||
return rs.s.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
@ -35,14 +35,14 @@ type RadiusServer struct {
|
|||||||
ac *ak.APIController
|
ac *ak.APIController
|
||||||
cryptoStore *ak.CryptoStore
|
cryptoStore *ak.CryptoStore
|
||||||
|
|
||||||
providers []*ProviderInstance
|
providers map[int32]*ProviderInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(ac *ak.APIController) ak.Outpost {
|
func NewServer(ac *ak.APIController) ak.Outpost {
|
||||||
rs := &RadiusServer{
|
rs := &RadiusServer{
|
||||||
log: log.WithField("logger", "authentik.outpost.radius"),
|
log: log.WithField("logger", "authentik.outpost.radius"),
|
||||||
ac: ac,
|
ac: ac,
|
||||||
providers: []*ProviderInstance{},
|
providers: map[int32]*ProviderInstance{},
|
||||||
cryptoStore: ak.NewCryptoStore(ac.Client.CryptoApi),
|
cryptoStore: ak.NewCryptoStore(ac.Client.CryptoApi),
|
||||||
}
|
}
|
||||||
rs.s = radius.PacketServer{
|
rs.s = radius.PacketServer{
|
||||||
@ -103,7 +103,8 @@ func (rs *RadiusServer) Start() error {
|
|||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := rs.StartRadiusServer()
|
rs.log.WithField("listen", rs.s.Addr).Info("Starting radius server")
|
||||||
|
err := rs.s.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user