Compare commits
1 Commits
openapi-ge
...
providers/
Author | SHA1 | Date | |
---|---|---|---|
739acf50f4 |
@ -35,7 +35,7 @@ type ProviderInstance struct {
|
|||||||
cert *tls.Certificate
|
cert *tls.Certificate
|
||||||
certUUID string
|
certUUID string
|
||||||
outpostName string
|
outpostName string
|
||||||
outpostPk int32
|
providerPk int32
|
||||||
searchAllowedGroups []*strfmt.UUID
|
searchAllowedGroups []*strfmt.UUID
|
||||||
boundUsersMutex *sync.RWMutex
|
boundUsersMutex *sync.RWMutex
|
||||||
boundUsers map[string]*flags.UserFlags
|
boundUsers map[string]*flags.UserFlags
|
||||||
|
@ -22,7 +22,7 @@ import (
|
|||||||
|
|
||||||
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
|
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||||
for _, p := range ls.providers {
|
for _, p := range ls.providers {
|
||||||
if p.outpostPk == pk {
|
if p.providerPk == pk {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -83,7 +83,7 @@ func (ls *LDAPServer) Refresh() error {
|
|||||||
gidStartNumber: provider.GetGidStartNumber(),
|
gidStartNumber: provider.GetGidStartNumber(),
|
||||||
mfaSupport: provider.GetMfaSupport(),
|
mfaSupport: provider.GetMfaSupport(),
|
||||||
outpostName: ls.ac.Outpost.Name,
|
outpostName: ls.ac.Outpost.Name,
|
||||||
outpostPk: provider.Pk,
|
providerPk: provider.Pk,
|
||||||
}
|
}
|
||||||
if kp := provider.Certificate.Get(); kp != nil {
|
if kp := provider.Certificate.Get(); kp != nil {
|
||||||
err := ls.cs.AddKeypair(*kp)
|
err := ls.cs.AddKeypair(*kp)
|
||||||
|
@ -6,8 +6,10 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseCIDRs(raw string) []*net.IPNet {
|
func parseCIDRs(raw string) []*net.IPNet {
|
||||||
@ -29,6 +31,25 @@ func parseCIDRs(raw string) []*net.IPNet {
|
|||||||
return cidrs
|
return cidrs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rs *RadiusServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||||
|
for _, p := range rs.providers {
|
||||||
|
if p.providerPk == pk {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *RadiusServer) getInvalidationFlow() string {
|
||||||
|
req, _, err := rs.ac.Client.CoreApi.CoreBrandsCurrentRetrieve(context.Background()).Execute()
|
||||||
|
if err != nil {
|
||||||
|
rs.log.WithError(err).Warning("failed to fetch brand config")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
flow := req.GetFlowInvalidation()
|
||||||
|
return flow
|
||||||
|
}
|
||||||
|
|
||||||
func (rs *RadiusServer) Refresh() error {
|
func (rs *RadiusServer) Refresh() error {
|
||||||
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
|
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -37,17 +58,33 @@ func (rs *RadiusServer) Refresh() error {
|
|||||||
if len(outposts.Results) < 1 {
|
if len(outposts.Results) < 1 {
|
||||||
return errors.New("no radius provider defined")
|
return errors.New("no radius provider defined")
|
||||||
}
|
}
|
||||||
|
invalidationFlow := rs.getInvalidationFlow()
|
||||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||||
for idx, provider := range outposts.Results {
|
for idx, provider := range outposts.Results {
|
||||||
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
||||||
|
|
||||||
|
// Get existing instance so we can transfer boundUsers
|
||||||
|
existing := rs.getCurrentProvider(provider.Pk)
|
||||||
|
usersMutex := &sync.RWMutex{}
|
||||||
|
users := make(map[string]*flags.UserFlags)
|
||||||
|
if existing != nil {
|
||||||
|
usersMutex = existing.boundUsersMutex
|
||||||
|
// Shallow copy, no need to lock
|
||||||
|
users = existing.boundUsers
|
||||||
|
}
|
||||||
|
|
||||||
providers[idx] = &ProviderInstance{
|
providers[idx] = &ProviderInstance{
|
||||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||||
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
ClientNetworks: parseCIDRs(provider.GetClientNetworks()),
|
||||||
MFASupport: provider.GetMfaSupport(),
|
MFASupport: provider.GetMfaSupport(),
|
||||||
appSlug: provider.ApplicationSlug,
|
appSlug: provider.ApplicationSlug,
|
||||||
flowSlug: provider.AuthFlowSlug,
|
authenticationFlowSlug: provider.AuthFlowSlug,
|
||||||
s: rs,
|
invalidationFlowSlug: invalidationFlow,
|
||||||
log: logger,
|
s: rs,
|
||||||
|
log: logger,
|
||||||
|
providerPk: provider.Pk,
|
||||||
|
boundUsersMutex: usersMutex,
|
||||||
|
boundUsers: users,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rs.providers = providers
|
rs.providers = providers
|
||||||
|
@ -4,15 +4,17 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/outpost/flow"
|
"goauthentik.io/internal/outpost/flow"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
"goauthentik.io/internal/outpost/radius/metrics"
|
"goauthentik.io/internal/outpost/radius/metrics"
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
"layeh.com/radius/rfc2865"
|
"layeh.com/radius/rfc2865"
|
||||||
|
"layeh.com/radius/rfc2866"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
|
func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusRequest) {
|
||||||
username := rfc2865.UserName_GetString(r.Packet)
|
username := rfc2865.UserName_GetString(r.Packet)
|
||||||
|
|
||||||
fe := flow.NewFlowExecutor(r.Context(), r.pi.flowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
|
fe := flow.NewFlowExecutor(r.Context(), r.pi.authenticationFlowSlug, r.pi.s.ac.Client.GetConfig(), log.Fields{
|
||||||
"username": username,
|
"username": username,
|
||||||
"client": r.RemoteAddr(),
|
"client": r.RemoteAddr(),
|
||||||
"requestId": r.ID(),
|
"requestId": r.ID(),
|
||||||
@ -64,5 +66,28 @@ func (rs *RadiusServer) Handle_AccessRequest(w radius.ResponseWriter, r *RadiusR
|
|||||||
}).Inc()
|
}).Inc()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_ = w.Write(r.Response(radius.CodeAccessAccept))
|
// Get user info to store in context
|
||||||
|
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(r.Context()).Execute()
|
||||||
|
if err != nil {
|
||||||
|
metrics.RequestsRejected.With(prometheus.Labels{
|
||||||
|
"outpost_name": rs.ac.Outpost.Name,
|
||||||
|
"type": "bind",
|
||||||
|
"reason": "user_info_fail",
|
||||||
|
}).Inc()
|
||||||
|
r.Log().WithError(err).Warning("failed to get user info")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := r.Response(radius.CodeAccessAccept)
|
||||||
|
_ = rfc2866.AcctSessionID_SetString(response, fe.GetSession().String())
|
||||||
|
r.pi.boundUsersMutex.Lock()
|
||||||
|
r.pi.boundUsers[fe.GetSession().String()] = &flags.UserFlags{
|
||||||
|
Session: fe.GetSession(),
|
||||||
|
UserPk: userInfo.Original.Pk,
|
||||||
|
}
|
||||||
|
r.pi.boundUsersMutex.Unlock()
|
||||||
|
err = w.Write(response)
|
||||||
|
if err != nil {
|
||||||
|
r.Log().WithError(err).Warning("failed to write response")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
54
internal/outpost/radius/handle_disconnect_request.go
Normal file
54
internal/outpost/radius/handle_disconnect_request.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package radius
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/flow"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
|
"layeh.com/radius"
|
||||||
|
"layeh.com/radius/rfc2866"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (rs *RadiusServer) Handle_DisconnectRequest(w radius.ResponseWriter, r *RadiusRequest) {
|
||||||
|
session := rfc2866.AcctSessionID_GetString(r.Packet)
|
||||||
|
|
||||||
|
sendFailResponse := func() {
|
||||||
|
failResponse := r.Response(radius.CodeDisconnectACK)
|
||||||
|
err := w.Write(failResponse)
|
||||||
|
if err != nil {
|
||||||
|
r.Log().WithError(err).Warning("failed to write response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.pi.boundUsersMutex.Lock()
|
||||||
|
var f *flags.UserFlags
|
||||||
|
if ff, ok := r.pi.boundUsers[session]; !ok {
|
||||||
|
r.pi.boundUsersMutex.Unlock()
|
||||||
|
sendFailResponse()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
f = ff
|
||||||
|
}
|
||||||
|
r.pi.boundUsersMutex.Unlock()
|
||||||
|
|
||||||
|
fe := flow.NewFlowExecutor(r.Context(), r.pi.invalidationFlowSlug, rs.ac.Client.GetConfig(), log.Fields{
|
||||||
|
"client": r.RemoteAddr(),
|
||||||
|
"requestId": r.ID(),
|
||||||
|
})
|
||||||
|
fe.SetSession(f.Session)
|
||||||
|
fe.DelegateClientIP(r.RemoteAddr())
|
||||||
|
fe.Params.Add("goauthentik.io/outpost/radius", "true")
|
||||||
|
_, err := fe.Execute()
|
||||||
|
if err != nil {
|
||||||
|
r.log.WithError(err).Warning("failed to logout user")
|
||||||
|
sendFailResponse()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.pi.boundUsersMutex.Lock()
|
||||||
|
delete(r.pi.boundUsers, session)
|
||||||
|
r.pi.boundUsersMutex.Unlock()
|
||||||
|
response := r.Response(radius.CodeDisconnectACK)
|
||||||
|
err = w.Write(response)
|
||||||
|
if err != nil {
|
||||||
|
r.Log().WithError(err).Warning("failed to write response")
|
||||||
|
}
|
||||||
|
}
|
@ -74,7 +74,12 @@ func (rs *RadiusServer) ServeRADIUS(w radius.ResponseWriter, r *radius.Request)
|
|||||||
}
|
}
|
||||||
nr.pi = pi
|
nr.pi = pi
|
||||||
|
|
||||||
if nr.Code == radius.CodeAccessRequest {
|
switch nr.Code {
|
||||||
|
case radius.CodeAccessRequest:
|
||||||
rs.Handle_AccessRequest(w, nr)
|
rs.Handle_AccessRequest(w, nr)
|
||||||
|
case radius.CodeDisconnectRequest:
|
||||||
|
rs.Handle_DisconnectRequest(w, nr)
|
||||||
|
default:
|
||||||
|
nr.Log().WithField("code", nr.Code.String()).Debug("Unsupported packet code")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,20 +9,25 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/internal/config"
|
"goauthentik.io/internal/config"
|
||||||
"goauthentik.io/internal/outpost/ak"
|
"goauthentik.io/internal/outpost/ak"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
"goauthentik.io/internal/outpost/radius/metrics"
|
"goauthentik.io/internal/outpost/radius/metrics"
|
||||||
|
|
||||||
"layeh.com/radius"
|
"layeh.com/radius"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ProviderInstance struct {
|
type ProviderInstance struct {
|
||||||
ClientNetworks []*net.IPNet
|
ClientNetworks []*net.IPNet
|
||||||
SharedSecret []byte
|
SharedSecret []byte
|
||||||
MFASupport bool
|
MFASupport bool
|
||||||
|
boundUsersMutex *sync.RWMutex
|
||||||
|
boundUsers map[string]*flags.UserFlags
|
||||||
|
providerPk int32
|
||||||
|
|
||||||
appSlug string
|
appSlug string
|
||||||
flowSlug string
|
authenticationFlowSlug string
|
||||||
s *RadiusServer
|
invalidationFlowSlug string
|
||||||
log *log.Entry
|
s *RadiusServer
|
||||||
|
log *log.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
type RadiusServer struct {
|
type RadiusServer struct {
|
||||||
|
Reference in New Issue
Block a user