outposts/ldap: cached bind (#2824)
* initial cached ldap bind support Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org> * add web Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org> * add docs Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org> * clean up api generation Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org> * use gh action for golangci-lint Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
@ -3,8 +3,10 @@ package ldap
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/nmcclain/ldap"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/utils"
|
||||
@ -24,6 +26,16 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
|
||||
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
|
||||
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
log.WithError(err.(error)).Error("recover in bind request")
|
||||
sentry.CaptureException(err.(error))
|
||||
}()
|
||||
|
||||
for _, instance := range ls.providers {
|
||||
username, err := instance.binder.GetUsername(bindDN)
|
||||
if err == nil {
|
||||
|
@ -66,6 +66,10 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
|
||||
fe.Answers[flow.StagePassword] = req.BindPW
|
||||
|
||||
passed, err := fe.Execute()
|
||||
flags := flags.UserFlags{
|
||||
Session: fe.GetSession(),
|
||||
}
|
||||
db.si.SetFlags(req.BindDN, flags)
|
||||
if !passed {
|
||||
metrics.RequestsRejected.With(prometheus.Labels{
|
||||
"outpost_name": db.si.GetOutpostName(),
|
||||
@ -74,6 +78,7 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
|
||||
"dn": req.BindDN,
|
||||
"client": req.RemoteAddr(),
|
||||
}).Inc()
|
||||
req.Log().Info("Invalid credentials")
|
||||
return ldap.LDAPResultInvalidCredentials, nil
|
||||
}
|
||||
if err != nil {
|
||||
@ -127,10 +132,8 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
|
||||
return ldap.LDAPResultOperationsError, nil
|
||||
}
|
||||
cs := db.SearchAccessCheck(userInfo.User)
|
||||
flags := flags.UserFlags{
|
||||
UserPk: userInfo.User.Pk,
|
||||
CanSearch: cs != nil,
|
||||
}
|
||||
flags.UserPk = userInfo.User.Pk
|
||||
flags.CanSearch = cs != nil
|
||||
db.si.SetFlags(req.BindDN, flags)
|
||||
if flags.CanSearch {
|
||||
req.Log().WithField("group", cs).Info("Allowed access to search")
|
||||
@ -143,6 +146,9 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
|
||||
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
|
||||
for _, group := range user.Groups {
|
||||
for _, allowedGroup := range db.si.GetSearchAllowedGroups() {
|
||||
if allowedGroup == nil {
|
||||
continue
|
||||
}
|
||||
db.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
|
||||
if group.Pk == allowedGroup.String() {
|
||||
return &group.Name
|
||||
|
62
internal/outpost/ldap/bind/memory/memory.go
Normal file
62
internal/outpost/ldap/bind/memory/memory.go
Normal file
@ -0,0 +1,62 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
ttlcache "github.com/jellydator/ttlcache/v3"
|
||||
"github.com/nmcclain/ldap"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
"goauthentik.io/internal/outpost/ldap/bind/direct"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
)
|
||||
|
||||
type Credentials struct {
|
||||
DN string
|
||||
Password string
|
||||
}
|
||||
|
||||
type SessionBinder struct {
|
||||
direct.DirectBinder
|
||||
si server.LDAPServerInstance
|
||||
log *log.Entry
|
||||
sessions *ttlcache.Cache[Credentials, ldap.LDAPResultCode]
|
||||
}
|
||||
|
||||
func NewSessionBinder(si server.LDAPServerInstance) *SessionBinder {
|
||||
sb := &SessionBinder{
|
||||
DirectBinder: *direct.NewDirectBinder(si),
|
||||
si: si,
|
||||
log: log.WithField("logger", "authentik.outpost.ldap.binder.session"),
|
||||
sessions: ttlcache.New(ttlcache.WithDisableTouchOnHit[Credentials, ldap.LDAPResultCode]()),
|
||||
}
|
||||
go sb.sessions.Start()
|
||||
sb.log.Info("initialised session binder")
|
||||
return sb
|
||||
}
|
||||
|
||||
func (sb *SessionBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
|
||||
item := sb.sessions.Get(Credentials{
|
||||
DN: req.BindDN,
|
||||
Password: req.BindPW,
|
||||
})
|
||||
if item != nil {
|
||||
sb.log.WithField("bindDN", req.BindDN).Info("authenticated from session")
|
||||
return item.Value(), nil
|
||||
}
|
||||
sb.log.Debug("No session found for user, executing flow")
|
||||
result, err := sb.DirectBinder.Bind(username, req)
|
||||
// Only cache the result if there's been an error
|
||||
if err == nil {
|
||||
flags, ok := sb.si.GetFlags(req.BindDN)
|
||||
if !ok {
|
||||
sb.log.Error("user flags not set after bind")
|
||||
return result, err
|
||||
}
|
||||
sb.sessions.Set(Credentials{
|
||||
DN: req.BindDN,
|
||||
Password: req.BindPW,
|
||||
}, result, time.Duration(flags.Session.MaxAge))
|
||||
}
|
||||
return result, err
|
||||
}
|
@ -1,9 +1,14 @@
|
||||
package flags
|
||||
|
||||
import "goauthentik.io/api/v3"
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
)
|
||||
|
||||
type UserFlags struct {
|
||||
UserInfo *api.User
|
||||
UserPk int32
|
||||
CanSearch bool
|
||||
Session *http.Cookie
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
|
||||
memorybind "goauthentik.io/internal/outpost/ldap/bind/memory"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
directsearch "goauthentik.io/internal/outpost/ldap/search/direct"
|
||||
@ -81,7 +82,11 @@ func (ls *LDAPServer) Refresh() error {
|
||||
} else if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_DIRECT {
|
||||
providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx])
|
||||
}
|
||||
providers[idx].binder = directbind.NewDirectBinder(providers[idx])
|
||||
if *provider.BindMode.Ptr() == api.BINDMODEENUM_CACHED {
|
||||
providers[idx].binder = memorybind.NewSessionBinder(providers[idx])
|
||||
} else if *provider.BindMode.Ptr() == api.BINDMODEENUM_DIRECT {
|
||||
providers[idx].binder = directbind.NewDirectBinder(providers[idx])
|
||||
}
|
||||
}
|
||||
ls.providers = providers
|
||||
ls.log.Info("Update providers")
|
||||
|
@ -31,8 +31,8 @@ type LDAPServerInstance interface {
|
||||
|
||||
UsersForGroup(api.Group) []string
|
||||
|
||||
GetFlags(string) (flags.UserFlags, bool)
|
||||
SetFlags(string, flags.UserFlags)
|
||||
GetFlags(dn string) (flags.UserFlags, bool)
|
||||
SetFlags(dn string, flags flags.UserFlags)
|
||||
|
||||
GetBaseEntry() *ldap.Entry
|
||||
GetNeededObjects(int, string, string) (bool, bool)
|
||||
|
Reference in New Issue
Block a user