Compare commits
1 Commits
providers/
...
providers/
Author | SHA1 | Date | |
---|---|---|---|
739acf50f4 |
@ -35,7 +35,7 @@ type ProviderInstance struct {
|
||||
cert *tls.Certificate
|
||||
certUUID string
|
||||
outpostName string
|
||||
outpostPk int32
|
||||
providerPk int32
|
||||
searchAllowedGroups []*strfmt.UUID
|
||||
boundUsersMutex *sync.RWMutex
|
||||
boundUsers map[string]*flags.UserFlags
|
||||
|
@ -22,7 +22,7 @@ import (
|
||||
|
||||
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||
for _, p := range ls.providers {
|
||||
if p.outpostPk == pk {
|
||||
if p.providerPk == pk {
|
||||
return p
|
||||
}
|
||||
}
|
||||
@ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error {
|
||||
gidStartNumber: provider.GetGidStartNumber(),
|
||||
mfaSupport: provider.GetMfaSupport(),
|
||||
outpostName: ls.ac.Outpost.Name,
|
||||
outpostPk: provider.Pk,
|
||||
providerPk: provider.Pk,
|
||||
}
|
||||
if kp := provider.Certificate.Get(); kp != nil {
|
||||
err := ls.cs.AddKeypair(*kp)
|
||||
|
@ -6,8 +6,10 @@ import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
)
|
||||
|
||||
func parseCIDRs(raw string) []*net.IPNet {
|
||||
@ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet {
|
||||
return cidrs
|
||||
}
|
||||
|
||||
func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||
for _, p := range rs.providers {
|
||||
if p.providerPk == pk {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RadiusServer) getInvalidationFlow() string {
|
||||
req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
|
||||
if err != nil {
|
||||
rs.log.WithError(err).Warning("failed to fetch brand config")
|
||||
return ""
|
||||
}
|
||||
flow := req.GetFlowInvalidation()
|
||||
return flow
|
||||
}
|
||||
|
||||
func (rs *RadiusServer) Refresh() error {
|
||||
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
|
||||
if err != nil {
|
||||
@ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error {
|
||||
if len(outposts.Results) < 1 {
|
||||
return errors.New("no radius provider defined")
|
||||
}
|
||||
invalidationFlow := rs.getInvalidationFlow()
|
||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||
for idx, provider := range outposts.Results {
|
||||
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
||||
|
||||
// Get existing instance so we can transfer boundUsers
|
||||
existing := rs.getCurrentProvider(provider.Pk)
|
||||
usersMutex := &sync.RWMutex{}
|
||||
users := make(map[string]*flags.UserFlags)
|
||||
if existing != nil {
|
||||
usersMutex = existing.boundUsersMutex
|
||||
// Shallow copy, no need to lock
|
||||
users = existing.boundUsers
|
||||
}
|
||||
|
||||
providers[idx] = &ProviderInstance{
|
||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
||||
MFASupport: provider.GetMfaSupport(),
|
||||
appSlug: provider.ApplicationSlug,
|
||||
flowSlug: provider.AuthFlowSlug,
|
||||
s: rs,
|
||||
log: logger,
|
||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
||||
MFASupport: provider.GetMfaSupport(),
|
||||
appSlug: provider.ApplicationSlug,
|
||||
authenticationFlowSlug: provider.AuthFlowSlug,
|
||||
invalidationFlowSlug: invalidationFlow,
|
||||
s: rs,
|
||||
log: logger,
|
||||
providerPk: provider.Pk,
|
||||
boundUsersMutex: usersMutex,
|
||||
boundUsers: users,
|
||||
}
|
||||
}
|
||||
rs.providers = providers
|
||||
|
@ -4,15 +4,17 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/flow"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/radius/metrics"
|
||||
"layeh.com/radius"
|
||||
"layeh.com/radius/rfc2865"
|
||||
"layeh.com/radius/rfc2866"
|
||||
)
|
||||
|
||||
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
|
||||
username := rfc2865.UserName_GetString(r.Packet)
|
||||
|
||||
fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
|
||||
fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
|
||||
"username": username,
|
||||
"client": r.RemoteAddr(),
|
||||
"requestId": r.ID(),
|
||||
@ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR
|
||||
}).Inc()
|
||||
return
|
||||
}
|
||||
_ = w.Write(r.Response(radius.CodeAccessAccept))
|
||||
// Get user info to store in context
|
||||
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute()
|
||||
if err != nil {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": rs.ac.Outpost.Name,
|
||||
"type": "bind",
|
||||
"reason": "user_info_fail",
|
||||
}).Inc()
|
||||
r.Log().WithError(err).Warning("failed to get user info")
|
||||
return
|
||||
}
|
||||
|
||||
response := r.Response(radius.CodeAccessAccept)
|
||||
_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String())
|
||||
r.pi.boundUsersMutex.Lock()
|
||||
r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{
|
||||
Session: fe.GetSession(),
|
||||
UserPk: userInfo.Original.Pk,
|
||||
}
|
||||
r.pi.boundUsersMutex.Unlock()
|
||||
err = w.Write(response)
|
||||
if err != nil {
|
||||
r.Log().WithError(err).Warning("failed to write response")
|
||||
}
|
||||
}
|
||||
|
54
internal/outpost/radius/handle_disconnect_request.go
Normal file
54
internal/outpost/radius/handle_disconnect_request.go
Normal file
@ -0,0 +1,54 @@
|
||||
package radius
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/flow"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"layeh.com/radius"
|
||||
"layeh.com/radius/rfc2866"
|
||||
)
|
||||
|
||||
func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) {
|
||||
session := rfc2866.AcctSessionID_GetString(r.Packet)
|
||||
|
||||
sendFailResponse := func() {
|
||||
failResponse := r.Response(radius.CodeDisconnectACK)
|
||||
err := w.Write(failResponse)
|
||||
if err != nil {
|
||||
r.Log().WithError(err).Warning("failed to write response")
|
||||
}
|
||||
}
|
||||
|
||||
r.pi.boundUsersMutex.Lock()
|
||||
var f *flags.UserFlags
|
||||
if ff, ok := r.pi.boundUsers[session]; !ok {
|
||||
r.pi.boundUsersMutex.Unlock()
|
||||
sendFailResponse()
|
||||
return
|
||||
} else {
|
||||
f = ff
|
||||
}
|
||||
r.pi.boundUsersMutex.Unlock()
|
||||
|
||||
fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{
|
||||
"client": r.RemoteAddr(),
|
||||
"requestId": r.ID(),
|
||||
})
|
||||
fe.SetSession(f.Session)
|
||||
fe.DelegateClientIP(r.RemoteAddr())
|
||||
fe.Params.Add("goauthentik.io/outpost/radius", "true")
|
||||
_, err := fe.Execute()
|
||||
if err != nil {
|
||||
r.log.WithError(err).Warning("failed to logout user")
|
||||
sendFailResponse()
|
||||
return
|
||||
}
|
||||
r.pi.boundUsersMutex.Lock()
|
||||
delete(r.pi.boundUsers, session)
|
||||
r.pi.boundUsersMutex.Unlock()
|
||||
response := r.Response(radius.CodeDisconnectACK)
|
||||
err = w.Write(response)
|
||||
if err != nil {
|
||||
r.Log().WithError(err).Warning("failed to write response")
|
||||
}
|
||||
}
|
@ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
|
||||
}
|
||||
nr.pi = pi
|
||||
|
||||
if nr.Code == radius.CodeAccessRequest {
|
||||
switch nr.Code {
|
||||
case radius.CodeAccessRequest:
|
||||
rs.Handle_AccessRequest(w, nr)
|
||||
case radius.CodeDisconnectRequest:
|
||||
rs.Handle_DisconnectRequest(w, nr)
|
||||
default:
|
||||
nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code")
|
||||
}
|
||||
}
|
||||
|
@ -9,20 +9,25 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/radius/metrics"
|
||||
|
||||
"layeh.com/radius"
|
||||
)
|
||||
|
||||
type ProviderInstance struct {
|
||||
ClientNetworks []*net.IPNet
|
||||
SharedSecret []byte
|
||||
MFASupport bool
|
||||
ClientNetworks []*net.IPNet
|
||||
SharedSecret []byte
|
||||
MFASupport bool
|
||||
boundUsersMutex *sync.RWMutex
|
||||
boundUsers map[string]*flags.UserFlags
|
||||
providerPk int32
|
||||
|
||||
appSlug string
|
||||
flowSlug string
|
||||
s *RadiusServer
|
||||
log *log.Entry
|
||||
appSlug string
|
||||
authenticationFlowSlug string
|
||||
invalidationFlowSlug string
|
||||
s *RadiusServer
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
type RadiusServer struct {
|
||||
|
Reference in New Issue
Block a user