Compare commits

...

1 Commits

Author SHA1 Message Date
739acf50f4 providers/radius: add logout support
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2024-04-01 03:34:07 +02:00
7 changed files with 146 additions and 20 deletions

View File

@ -35,7 +35,7 @@ type ProviderInstance struct {
cert *tls.Certificate
certUUID string
outpostName string
outpostPk int32
providerPk int32
searchAllowedGroups []*strfmt.UUID
boundUsersMutex *sync.RWMutex
boundUsers map[string]*flags.UserFlags

View File

@ -22,7 +22,7 @@ import (
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
for _, p := range ls.providers {
if p.outpostPk == pk {
if p.providerPk == pk {
return p
}
}
@ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error {
gidStartNumber: provider.GetGidStartNumber(),
mfaSupport: provider.GetMfaSupport(),
outpostName: ls.ac.Outpost.Name,
outpostPk: provider.Pk,
providerPk: provider.Pk,
}
if kp := provider.Certificate.Get(); kp != nil {
err := ls.cs.AddKeypair(*kp)

View File

@ -6,8 +6,10 @@ import (
"net"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/flags"
)
func parseCIDRs(raw string) []*net.IPNet {
@ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet {
return cidrs
}
func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
for _, p := range rs.providers {
if p.providerPk == pk {
return p
}
}
return nil
}
func (rs *RadiusServer) getInvalidationFlow() string {
req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
if err != nil {
rs.log.WithError(err).Warning("failed to fetch brand config")
return ""
}
flow := req.GetFlowInvalidation()
return flow
}
func (rs *RadiusServer) Refresh() error {
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
if err != nil {
@ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error {
if len(outposts.Results) < 1 {
return errors.New("no radius provider defined")
}
invalidationFlow := rs.getInvalidationFlow()
providers := make([]*ProviderInstance, len(outposts.Results))
for idx, provider := range outposts.Results {
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
// Get existing instance so we can transfer boundUsers
existing := rs.getCurrentProvider(provider.Pk)
usersMutex := &sync.RWMutex{}
users := make(map[string]*flags.UserFlags)
if existing != nil {
usersMutex = existing.boundUsersMutex
// Shallow copy, no need to lock
users = existing.boundUsers
}
providers[idx] = &ProviderInstance{
SharedSecret: []byte(provider.GetSharedSecret()),
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
MFASupport: provider.GetMfaSupport(),
appSlug: provider.ApplicationSlug,
flowSlug: provider.AuthFlowSlug,
s: rs,
log: logger,
SharedSecret: []byte(provider.GetSharedSecret()),
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
MFASupport: provider.GetMfaSupport(),
appSlug: provider.ApplicationSlug,
authenticationFlowSlug: provider.AuthFlowSlug,
invalidationFlowSlug: invalidationFlow,
s: rs,
log: logger,
providerPk: provider.Pk,
boundUsersMutex: usersMutex,
boundUsers: users,
}
}
rs.providers = providers

View File

@ -4,15 +4,17 @@ import (
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius"
"layeh.com/radius/rfc2865"
"layeh.com/radius/rfc2866"
)
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
username := rfc2865.UserName_GetString(r.Packet)
fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
"username": username,
"client": r.RemoteAddr(),
"requestId": r.ID(),
@ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR
}).Inc()
return
}
_ = w.Write(r.Response(radius.CodeAccessAccept))
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": rs.ac.Outpost.Name,
"type": "bind",
"reason": "user_info_fail",
}).Inc()
r.Log().WithError(err).Warning("failed to get user info")
return
}
response := r.Response(radius.CodeAccessAccept)
_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String())
r.pi.boundUsersMutex.Lock()
r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{
Session: fe.GetSession(),
UserPk: userInfo.Original.Pk,
}
r.pi.boundUsersMutex.Unlock()
err = w.Write(response)
if err != nil {
r.Log().WithError(err).Warning("failed to write response")
}
}

View File

@ -0,0 +1,54 @@
package radius
import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/flags"
"layeh.com/radius"
"layeh.com/radius/rfc2866"
)
func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) {
session := rfc2866.AcctSessionID_GetString(r.Packet)
sendFailResponse := func() {
failResponse := r.Response(radius.CodeDisconnectACK)
err := w.Write(failResponse)
if err != nil {
r.Log().WithError(err).Warning("failed to write response")
}
}
r.pi.boundUsersMutex.Lock()
var f *flags.UserFlags
if ff, ok := r.pi.boundUsers[session]; !ok {
r.pi.boundUsersMutex.Unlock()
sendFailResponse()
return
} else {
f = ff
}
r.pi.boundUsersMutex.Unlock()
fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{
"client": r.RemoteAddr(),
"requestId": r.ID(),
})
fe.SetSession(f.Session)
fe.DelegateClientIP(r.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/radius", "true")
_, err := fe.Execute()
if err != nil {
r.log.WithError(err).Warning("failed to logout user")
sendFailResponse()
return
}
r.pi.boundUsersMutex.Lock()
delete(r.pi.boundUsers, session)
r.pi.boundUsersMutex.Unlock()
response := r.Response(radius.CodeDisconnectACK)
err = w.Write(response)
if err != nil {
r.Log().WithError(err).Warning("failed to write response")
}
}

View File

@ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
}
nr.pi = pi
if nr.Code == radius.CodeAccessRequest {
switch nr.Code {
case radius.CodeAccessRequest:
rs.Handle_AccessRequest(w, nr)
case radius.CodeDisconnectRequest:
rs.Handle_DisconnectRequest(w, nr)
default:
nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code")
}
}

View File

@ -9,20 +9,25 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/radius/metrics"
"layeh.com/radius"
)
type ProviderInstance struct {
ClientNetworks []*net.IPNet
SharedSecret []byte
MFASupport bool
ClientNetworks []*net.IPNet
SharedSecret []byte
MFASupport bool
boundUsersMutex *sync.RWMutex
boundUsers map[string]*flags.UserFlags
providerPk int32
appSlug string
flowSlug string
s *RadiusServer
log *log.Entry
appSlug string
authenticationFlowSlug string
invalidationFlowSlug string
s *RadiusServer
log *log.Entry
}
type RadiusServer struct {