Files
authentik/internal/outpost/radius/api.go
Jens Langhammer 739acf50f4 providers/radius: add logout support
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-01 03:34:07 +02:00

99 lines
2.7 KiB
Go

package radius
import (
"context"
"errors"
"net"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/flags"
)
// parseCIDRs parses a comma-separated list of CIDR strings into networks.
//
// Each entry has surrounding whitespace trimmed before it is handed to
// net.ParseCIDR. Entries that fail to parse are logged and omitted, so the
// returned slice never contains nil elements. The result is sorted in
// ascending order of the second value returned by net.IPMask.Size (the total
// bit length of the mask).
func parseCIDRs(raw string) []*net.IPNet {
	parts := strings.Split(raw, ",")
	// Start empty and append: filling by index would leave nil holes for
	// invalid entries, and the sort comparator below dereferences every
	// element.
	cidrs := make([]*net.IPNet, 0, len(parts))
	for _, p := range parts {
		_, ipnet, err := net.ParseCIDR(strings.TrimSpace(p))
		if err != nil {
			log.WithError(err).WithField("cidr", p).Error("Failed to parse CIDR")
			continue
		}
		cidrs = append(cidrs, ipnet)
	}
	// NOTE(review): Size() returns (ones, bits); this orders by bits, i.e. the
	// overall mask length, not by prefix length — confirm that is intended.
	sort.Slice(cidrs, func(i, j int) bool {
		_, bi := cidrs[i].Mask.Size()
		_, bj := cidrs[j].Mask.Size()
		return bi < bj
	})
	return cidrs
}
// getCurrentProvider looks through the providers currently held by the server
// and returns the first one whose providerPk equals pk, or nil when none match.
func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
	for i := range rs.providers {
		if candidate := rs.providers[i]; candidate.providerPk == pk {
			return candidate
		}
	}
	return nil
}
// getInvalidationFlow retrieves the current brand through the API client and
// returns the value of its GetFlowInvalidation accessor. If the request fails,
// a warning is logged and an empty string is returned.
func (rs *RadiusServer) getInvalidationFlow() string {
	brand, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
	if err == nil {
		return brand.GetFlowInvalidation()
	}
	rs.log.WithError(err).Warning("failed to fetch brand config")
	return ""
}
// Refresh reloads the RADIUS provider list from the API and replaces
// rs.providers with freshly built instances.
//
// It returns the API error if the list request fails, and an error when the
// response contains no providers. For a provider whose Pk matches an instance
// already held by the server, the existing boundUsers map and its mutex are
// carried over to the new instance; otherwise new ones are created.
func (rs *RadiusServer) Refresh() error {
	ctx := context.Background()
	resp, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(ctx).Execute()
	if err != nil {
		return err
	}
	count := len(resp.Results)
	if count == 0 {
		return errors.New("no radius provider defined")
	}

	flowSlug := rs.getInvalidationFlow()
	refreshed := make([]*ProviderInstance, count)
	for i, cfg := range resp.Results {
		var (
			mu    *sync.RWMutex
			bound map[string]*flags.UserFlags
		)
		if prev := rs.getCurrentProvider(cfg.Pk); prev != nil {
			// Reuse the previous instance's map and mutex as-is; only the
			// references are taken, so no locking is done here.
			mu, bound = prev.boundUsersMutex, prev.boundUsers
		} else {
			mu, bound = &sync.RWMutex{}, make(map[string]*flags.UserFlags)
		}

		refreshed[i] = &ProviderInstance{
			SharedSecret:           []byte(cfg.GetSharedSecret()),
			ClientNetworks:         parseCIDRs(cfg.GetClientNetworks()),
			MFASupport:             cfg.GetMfaSupport(),
			appSlug:                cfg.ApplicationSlug,
			authenticationFlowSlug: cfg.AuthFlowSlug,
			invalidationFlowSlug:   flowSlug,
			s:                      rs,
			log:                    log.WithField("logger", "authentik.outpost.radius").WithField("provider", cfg.Name),
			providerPk:             cfg.Pk,
			boundUsersMutex:        mu,
			boundUsers:             bound,
		}
	}

	rs.providers = refreshed
	rs.log.Info("Update providers")
	return nil
}
// StartRadiusServer logs the configured listen address and then calls
// ListenAndServe on the underlying server, returning whatever it returns.
func (rs *RadiusServer) StartRadiusServer() error {
	addr := rs.s.Addr
	rs.log.WithField("listen", addr).Info("Starting radius server")
	return rs.s.ListenAndServe()
}