providers/ldap: memory Query (#1681)

* outposts/ldap: modularise ldap outpost, to allow different searchers and binders

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* outposts/ldap: add basic in-memory searcher

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* providers/ldap: add search mode field

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* outpost: add search mode field

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L
2021-11-05 10:37:30 +01:00
committed by GitHub
parent 8de13d3f67
commit 5a8c66d325
37 changed files with 1293 additions and 639 deletions

View File

@ -1,23 +0,0 @@
package ldap
import "crypto/tls"
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(ls.providers) == 1 {
if ls.providers[0].cert != nil {
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
return ls.providers[0].cert, nil
}
}
for _, provider := range ls.providers {
if provider.tlsServerName == &info.ServerName {
if provider.cert == nil {
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
return ls.defaultCert, nil
}
return provider.cert, nil
}
}
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
return ls.defaultCert, nil
}

View File

@ -1,44 +1,18 @@
package ldap
import (
"context"
"net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils"
)
type BindRequest struct {
BindDN string
BindPW string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LDAPResultCode, error) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind",
sentry.TransactionName("authentik.providers.ldap.bind"))
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
req, span := bind.NewRequest(bindDN, bindPW, conn)
bindDN = strings.ToLower(bindDN)
req := BindRequest{
BindDN: bindDN,
BindPW: bindPW,
conn: conn,
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
id: rid,
ctx: span.Context(),
}
defer func() {
span.Finish()
metrics.Requests.With(prometheus.Labels{
@ -46,19 +20,19 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
"type": "bind",
"filter": "",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
}()
for _, instance := range ls.providers {
username, err := instance.getUsername(bindDN)
username, err := instance.binder.GetUsername(bindDN)
if err == nil {
return instance.Bind(username, req)
return instance.binder.Bind(username, req)
} else {
req.log.WithError(err).Debug("Username not for instance")
req.Log().WithError(err).Debug("Username not for instance")
}
}
req.log.WithField("request", "bind").Warning("No provider found for request")
req.Log().WithField("request", "bind").Warning("No provider found for request")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name,
"type": "bind",
@ -68,10 +42,3 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
}).Inc()
return ldap.LDAPResultOperationsError, nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry() {
for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
p.TimerFlowCacheExpiry()
}
}

View File

@ -0,0 +1,9 @@
package bind
import "github.com/nmcclain/ldap"
type Binder interface {
GetUsername(string) (string, error)
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry()
}

View File

@ -1,4 +1,4 @@
package ldap
package direct
import (
"context"
@ -12,14 +12,30 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils"
"goauthentik.io/internal/outpost/ldap/server"
)
const ContextUserKey = "ak_user"
func (pi *ProviderInstance) getUsername(dn string) (string, error) {
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(pi.BaseDN)) {
type DirectBinder struct {
si server.LDAPServerInstance
log *log.Entry
}
func NewDirectBinder(si server.LDAPServerInstance) *DirectBinder {
db := &DirectBinder{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.binder.direct"),
}
db.log.Info("initialised direct binder")
return db
}
func (db *DirectBinder) GetUsername(dn string) (string, error) {
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(db.si.GetBaseDN())) {
return "", errors.New("invalid base DN")
}
dns, err := goldap.ParseDN(dn)
@ -36,13 +52,13 @@ func (pi *ProviderInstance) getUsername(dn string) (string, error) {
return "", errors.New("failed to find cn")
}
func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPResultCode, error) {
fe := outpost.NewFlowExecutor(req.ctx, pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := outpost.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"requestId": req.id,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.DelegateClientIP(req.conn.RemoteAddr())
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[outpost.StageIdentification] = username
@ -51,83 +67,82 @@ func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPRes
passed, err := fe.Execute()
if !passed {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "invalid_credentials",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Inc()
return ldap.LDAPResultInvalidCredentials, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "flow_error",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Inc()
req.log.WithError(err).Warning("failed to execute flow")
req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultOperationsError, nil
}
access, err := fe.CheckApplicationAccess(pi.appSlug)
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access {
req.log.Info("Access denied for user")
req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_denied",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_check_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Inc()
req.log.WithError(err).Warning("failed to check access")
req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil
}
req.log.Info("User has access")
uisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.bind.user_info")
req.Log().Info("User has access")
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "user_info_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": req.RemoteAddr(),
}).Inc()
req.log.WithError(err).Warning("failed to get user info")
req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil
}
pi.boundUsersMutex.Lock()
cs := pi.SearchAccessCheck(userInfo.User)
pi.boundUsers[req.BindDN] = UserFlags{
cs := db.SearchAccessCheck(userInfo.User)
flags := flags.UserFlags{
UserPk: userInfo.User.Pk,
CanSearch: cs != nil,
}
if pi.boundUsers[req.BindDN].CanSearch {
req.log.WithField("group", cs).Info("Allowed access to search")
db.si.SetFlags(req.BindDN, flags)
if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
}
uisp.Finish()
defer pi.boundUsersMutex.Unlock()
return ldap.LDAPResultSuccess, nil
}
// SearchAccessCheck Check if the current user is allowed to search
func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string {
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
for _, group := range user.Groups {
for _, allowedGroup := range pi.searchAllowedGroups {
pi.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
for _, allowedGroup := range db.si.GetSearchAllowedGroups() {
db.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
if group.Pk == allowedGroup.String() {
return &group.Name
}
@ -136,13 +151,13 @@ func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string {
return nil
}
func (pi *ProviderInstance) TimerFlowCacheExpiry() {
fe := outpost.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{})
func (db *DirectBinder) TimerFlowCacheExpiry() {
fe := outpost.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")
err := fe.WarmUp()
if err != nil {
pi.log.WithError(err).Warning("failed to warm up flow cache")
db.log.WithError(err).Warning("failed to warm up flow cache")
}
}

View File

@ -0,0 +1,55 @@
package bind
import (
"context"
"net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/utils"
)
type Request struct {
BindDN string
BindPW string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func NewRequest(bindDN string, bindPW string, conn net.Conn) (*Request, *sentry.Span) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind",
sentry.TransactionName("authentik.providers.ldap.bind"))
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
bindDN = strings.ToLower(bindDN)
return &Request{
BindDN: bindDN,
BindPW: bindPW,
conn: conn,
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
id: rid,
ctx: span.Context(),
}, span
}
func (r *Request) Context() context.Context {
return r.ctx
}
func (r *Request) Log() *log.Entry {
return r.log
}
func (r *Request) RemoteAddr() string {
return utils.GetIP(r.conn.RemoteAddr())
}
func (r *Request) ID() string {
return r.id
}

View File

@ -0,0 +1,21 @@
package constants
const (
OCGroup = "group"
OCGroupOfUniqueNames = "groupOfUniqueNames"
OCAKGroup = "goauthentik.io/ldap/group"
OCAKVirtualGroup = "goauthentik.io/ldap/virtual-group"
)
const (
OCUser = "user"
OCOrgPerson = "organizationalPerson"
OCInetOrgPerson = "inetOrgPerson"
OCAKUser = "goauthentik.io/ldap/user"
)
const (
OUUsers = "users"
OUGroups = "groups"
OUVirtualGroups = "virtual-groups"
)

View File

@ -0,0 +1,39 @@
package ldap
import (
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/utils"
)
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
dn := pi.GetUserDN(u.Username)
attrs := utils.AKAttrsToLDAP(u.Attributes)
attrs = utils.EnsureAttributes(attrs, map[string][]string{
"memberOf": pi.GroupsForUser(u),
// Old fields for backwards compatibility
"accountStatus": {utils.BoolToString(*u.IsActive)},
"superuser": {utils.BoolToString(u.IsSuperuser)},
// End old fields
"goauthentik.io/ldap/active": {utils.BoolToString(*u.IsActive)},
"goauthentik.io/ldap/superuser": {utils.BoolToString(u.IsSuperuser)},
"cn": {u.Username},
"sAMAccountName": {u.Username},
"uid": {u.Uid},
"name": {u.Name},
"displayName": {u.Name},
"mail": {*u.Email},
"objectClass": {constants.OCUser, constants.OCOrgPerson, constants.OCInetOrgPerson, constants.OCAKUser},
"uidNumber": {pi.GetUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)},
})
return &ldap.Entry{DN: dn, Attributes: attrs}
}
func (pi *ProviderInstance) GroupEntry(g group.LDAPGroup) *ldap.Entry {
// TODO: Remove
return g.Entry()
}

View File

@ -0,0 +1,9 @@
package flags
import "goauthentik.io/api"
type UserFlags struct {
UserInfo *api.User
UserPk int32
CanSearch bool
}

View File

@ -0,0 +1,66 @@
package group
import (
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
)
type LDAPGroup struct {
DN string
CN string
Uid string
GidNumber string
Member []string
IsSuperuser bool
IsVirtualGroup bool
AKAttributes interface{}
}
func (lg *LDAPGroup) Entry() *ldap.Entry {
attrs := utils.AKAttrsToLDAP(lg.AKAttributes)
objectClass := []string{constants.OCGroup, constants.OCGroupOfUniqueNames, constants.OCAKGroup}
if lg.IsVirtualGroup {
objectClass = append(objectClass, constants.OCAKVirtualGroup)
}
attrs = utils.EnsureAttributes(attrs, map[string][]string{
"objectClass": objectClass,
"member": lg.Member,
"goauthentik.io/ldap/superuser": {utils.BoolToString(lg.IsSuperuser)},
"cn": {lg.CN},
"uid": {lg.Uid},
"sAMAccountName": {lg.CN},
"gidNumber": {lg.GidNumber},
})
return &ldap.Entry{DN: lg.DN, Attributes: attrs}
}
func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
return &LDAPGroup{
DN: si.GetGroupDN(g.Name),
CN: g.Name,
Uid: string(g.Pk),
GidNumber: si.GetGidNumber(g),
Member: si.UsersForGroup(g),
IsVirtualGroup: false,
IsSuperuser: *g.IsSuperuser,
AKAttributes: g.Attributes,
}
}
func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
return &LDAPGroup{
DN: si.GetVirtualGroupDN(u.Username),
CN: u.Username,
Uid: u.Uid,
GidNumber: si.GetUidNumber(u),
Member: []string{si.GetUserDN(u.Username)},
IsVirtualGroup: true,
IsSuperuser: false,
AKAttributes: nil,
}
}

View File

@ -0,0 +1,4 @@
package handler
type Handler interface {
}

View File

@ -0,0 +1,83 @@
package ldap
import (
"crypto/tls"
"sync"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/search"
)
type ProviderInstance struct {
BaseDN string
UserDN string
VirtualGroupDN string
GroupDN string
searcher search.Searcher
binder bind.Binder
appSlug string
flowSlug string
s *LDAPServer
log *log.Entry
tlsServerName *string
cert *tls.Certificate
outpostName string
searchAllowedGroups []*strfmt.UUID
boundUsersMutex sync.RWMutex
boundUsers map[string]flags.UserFlags
uidStartNumber int32
gidStartNumber int32
}
func (pi *ProviderInstance) GetAPIClient() *api.APIClient {
return pi.s.ac.Client
}
func (pi *ProviderInstance) GetBaseDN() string {
return pi.BaseDN
}
func (pi *ProviderInstance) GetBaseGroupDN() string {
return pi.GroupDN
}
func (pi *ProviderInstance) GetBaseUserDN() string {
return pi.UserDN
}
func (pi *ProviderInstance) GetOutpostName() string {
return pi.outpostName
}
func (pi *ProviderInstance) GetFlags(dn string) (flags.UserFlags, bool) {
pi.boundUsersMutex.RLock()
flags, ok := pi.boundUsers[dn]
pi.boundUsersMutex.RUnlock()
return flags, ok
}
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) {
pi.boundUsersMutex.Lock()
pi.boundUsers[dn] = flag
pi.boundUsersMutex.Unlock()
}
func (pi *ProviderInstance) GetAppSlug() string {
return pi.appSlug
}
func (pi *ProviderInstance) GetFlowSlug() string {
return pi.flowSlug
}
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
return pi.searchAllowedGroups
}

View File

@ -1,244 +0,0 @@
package ldap
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils"
)
func (pi *ProviderInstance) SearchMe(req SearchRequest, f UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
u, _, err := pi.s.ac.Client.CoreApi.CoreUsersRetrieve(req.ctx, f.UserPk).Execute()
if err != nil {
req.log.WithError(err).Warning("Failed to get user info")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
f.UserInfo = &u
}
entries := make([]*ldap.Entry, 1)
entries[0] = pi.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + pi.BaseDN)
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, pi.BaseDN)
}
pi.boundUsersMutex.RLock()
flags, ok := pi.boundUsers[req.BindDN]
pi.boundUsersMutex.RUnlock()
if !ok {
pi.log.Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.SearchRequest.Scope == ldap.ScopeBaseObject {
pi.log.Debug("base scope, showing domain info")
return pi.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
pi.log.Debug("User can't search, showing info about user")
return pi.SearchMe(req, flags)
}
accsp.Finish()
parsedFilter, err := ldap.CompileFilter(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
// Create a custom client to set additional headers
c := api.NewAPIClient(pi.s.ac.Client.GetConfig())
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case "groupOfUniqueNames":
fallthrough
case "goauthentik.io/ldap/group":
fallthrough
case "goauthentik.io/ldap/virtual-group":
fallthrough
case GroupObjectClass:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
gapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_group")
searchReq, skip := parseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return
}
groups, _, err := searchReq.Execute()
gapisp.Finish()
if err != nil {
req.log.WithError(err).Warning("failed to get groups")
return
}
pi.log.WithField("count", len(groups.Results)).Trace("Got results from API")
for _, g := range groups.Results {
gEntries = append(gEntries, pi.GroupEntry(pi.APIGroupToLDAPGroup(g)))
}
}()
go func() {
defer wg.Done()
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
req.log.WithError(err).Warning("failed to get users")
return
}
for _, u := range users.Results {
uEntries = append(uEntries, pi.GroupEntry(pi.APIUserToLDAPGroup(u)))
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case "organizationalPerson":
fallthrough
case "inetOrgPerson":
fallthrough
case "goauthentik.io/ldap/user":
fallthrough
case UserObjectClass:
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
}
for _, u := range users.Results {
entries = append(entries, pi.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
dn := pi.GetUserDN(u.Username)
attrs := AKAttrsToLDAP(u.Attributes)
attrs = pi.ensureAttributes(attrs, map[string][]string{
"memberOf": pi.GroupsForUser(u),
// Old fields for backwards compatibility
"accountStatus": {BoolToString(*u.IsActive)},
"superuser": {BoolToString(u.IsSuperuser)},
"goauthentik.io/ldap/active": {BoolToString(*u.IsActive)},
"goauthentik.io/ldap/superuser": {BoolToString(u.IsSuperuser)},
"cn": {u.Username},
"sAMAccountName": {u.Username},
"uid": {u.Uid},
"name": {u.Name},
"displayName": {u.Name},
"mail": {*u.Email},
"objectClass": {UserObjectClass, "organizationalPerson", "inetOrgPerson", "goauthentik.io/ldap/user"},
"uidNumber": {pi.GetUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)},
})
return &ldap.Entry{DN: dn, Attributes: attrs}
}
func (pi *ProviderInstance) GroupEntry(g LDAPGroup) *ldap.Entry {
attrs := AKAttrsToLDAP(g.akAttributes)
objectClass := []string{GroupObjectClass, "groupOfUniqueNames", "goauthentik.io/ldap/group"}
if g.isVirtualGroup {
objectClass = append(objectClass, "goauthentik.io/ldap/virtual-group")
}
attrs = pi.ensureAttributes(attrs, map[string][]string{
"objectClass": objectClass,
"member": g.member,
"goauthentik.io/ldap/superuser": {BoolToString(g.isSuperuser)},
"cn": {g.cn},
"uid": {g.uid},
"sAMAccountName": {g.cn},
"gidNumber": {g.gidNumber},
})
return &ldap.Entry{DN: g.dn, Attributes: attrs}
}

View File

@ -2,50 +2,18 @@ package ldap
import (
"crypto/tls"
"net"
"sync"
"github.com/go-openapi/strfmt"
"github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/crypto"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/metrics"
"github.com/nmcclain/ldap"
)
const GroupObjectClass = "group"
const UserObjectClass = "user"
type ProviderInstance struct {
BaseDN string
UserDN string
VirtualGroupDN string
GroupDN string
appSlug string
flowSlug string
s *LDAPServer
log *log.Entry
tlsServerName *string
cert *tls.Certificate
outpostName string
searchAllowedGroups []*strfmt.UUID
boundUsersMutex sync.RWMutex
boundUsers map[string]UserFlags
uidStartNumber int32
gidStartNumber int32
}
type UserFlags struct {
UserInfo *api.User
UserPk int32
CanSearch bool
}
type LDAPServer struct {
s *ldap.Server
log *log.Entry
@ -55,17 +23,6 @@ type LDAPServer struct {
providers []*ProviderInstance
}
type LDAPGroup struct {
dn string
cn string
uid string
gidNumber string
member []string
isSuperuser bool
isVirtualGroup bool
akAttributes interface{}
}
func NewServer(ac *ak.APIController) *LDAPServer {
s := ldap.NewServer()
s.EnforceLDAP = true
@ -90,3 +47,54 @@ func NewServer(ac *ak.APIController) *LDAPServer {
func (ls *LDAPServer) Type() string {
return "ldap"
}
func (ls *LDAPServer) StartLDAPServer() error {
listen := "0.0.0.0:3389"
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
ls.log.WithField("listen", listen).Info("Starting ldap server")
err = ls.s.Serve(proxyListener)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := ls.StartLDAPServer()
if err != nil {
panic(err)
}
}()
go func() {
defer wg.Done()
err := ls.StartLDAPTLSServer()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry() {
for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
p.binder.TimerFlowCacheExpiry()
}
}

View File

@ -0,0 +1,55 @@
package ldap
import (
"crypto/tls"
"net"
"github.com/pires/go-proxyproto"
)
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(ls.providers) == 1 {
if ls.providers[0].cert != nil {
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
return ls.providers[0].cert, nil
}
}
for _, provider := range ls.providers {
if provider.tlsServerName == &info.ServerName {
if provider.cert == nil {
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
return ls.defaultCert, nil
}
return provider.cert, nil
}
}
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
return ls.defaultCert, nil
}
func (ls *LDAPServer) StartLDAPTLSServer() error {
listen := "0.0.0.0:6636"
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
GetCertificate: ls.getCertificates,
}
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
tln := tls.NewListener(proxyListener, tlsConfig)
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
err = ls.s.Serve(tln)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}

View File

@ -2,23 +2,19 @@ package ldap
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strings"
"sync"
"github.com/go-openapi/strfmt"
"github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/metrics"
)
const (
UsersOU = "users"
GroupsOU = "groups"
VirtualGroupsOU = "virtual-groups"
"goauthentik.io/api"
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
directsearch "goauthentik.io/internal/outpost/ldap/search/direct"
memorysearch "goauthentik.io/internal/outpost/ldap/search/memory"
)
func (ls *LDAPServer) Refresh() error {
@ -31,9 +27,9 @@ func (ls *LDAPServer) Refresh() error {
}
providers := make([]*ProviderInstance, len(outposts.Results))
for idx, provider := range outposts.Results {
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", UsersOU, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", GroupsOU, *provider.BaseDn))
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", VirtualGroupsOU, *provider.BaseDn))
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name)
providers[idx] = &ProviderInstance{
BaseDN: *provider.BaseDn,
@ -44,7 +40,7 @@ func (ls *LDAPServer) Refresh() error {
flowSlug: provider.BindFlowSlug,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsersMutex: sync.RWMutex{},
boundUsers: make(map[string]UserFlags),
boundUsers: make(map[string]flags.UserFlags),
s: ls,
log: logger,
tlsServerName: provider.TlsServerName,
@ -60,79 +56,14 @@ func (ls *LDAPServer) Refresh() error {
}
providers[idx].cert = ls.cs.Get(*kp)
}
if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_CACHED {
providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx])
} else if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_DIRECT {
providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx])
}
providers[idx].binder = directbind.NewDirectBinder(providers[idx])
}
ls.providers = providers
ls.log.Info("Update providers")
return nil
}
func (ls *LDAPServer) StartLDAPServer() error {
listen := "0.0.0.0:3389"
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
ls.log.WithField("listen", listen).Info("Starting ldap server")
err = ls.s.Serve(proxyListener)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) StartLDAPTLSServer() error {
listen := "0.0.0.0:6636"
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
GetCertificate: ls.getCertificates,
}
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
tln := tls.NewListener(proxyListener, tlsConfig)
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
err = ls.s.Serve(tln)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := ls.StartLDAPServer()
if err != nil {
panic(err)
}
}()
go func() {
defer wg.Done()
err := ls.StartLDAPTLSServer()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}

View File

@ -1,47 +1,21 @@
package ldap
import (
"context"
"errors"
"net"
"strings"
"github.com/getsentry/sentry-go"
goldap "github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/utils"
)
type SearchRequest struct {
ldap.SearchRequest
BindDN string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
span.SetTag("ak_filter", searchReq.Filter)
span.SetTag("ak_base_dn", searchReq.BaseDN)
bindDN = strings.ToLower(bindDN)
req := SearchRequest{
SearchRequest: searchReq,
BindDN: bindDN,
conn: conn,
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
id: rid,
ctx: span.Context(),
}
req, span := search.NewRequest(bindDN, searchReq, conn)
defer func() {
span.Finish()
@ -50,9 +24,9 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
"type": "search",
"filter": req.Filter,
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
"client": utils.GetIP(conn.RemoteAddr()),
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request")
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request")
}()
defer func() {
@ -69,13 +43,13 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
}
bd, err := goldap.ParseDN(searchReq.BaseDN)
if err != nil {
req.log.WithError(err).Info("failed to parse basedn")
req.Log().WithError(err).Info("failed to parse basedn")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN")
}
for _, provider := range ls.providers {
providerBase, _ := goldap.ParseDN(provider.BaseDN)
if providerBase.AncestorOf(bd) || providerBase.Equal(bd) {
return provider.Search(req)
return provider.searcher.Search(req)
}
}
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("no provider could handle request")

View File

@ -1,13 +1,14 @@
package ldap
package direct
import (
"fmt"
"github.com/nmcclain/ldap"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ldap/search"
)
func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.ServerSearchResult, error) {
func (ds *DirectSearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
dn := ""
if authz {
dn = req.SearchRequest.BaseDN
@ -19,7 +20,7 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
Attributes: []*ldap.EntryAttribute{
{
Name: "distinguishedName",
Values: []string{pi.BaseDN},
Values: []string{ds.si.GetBaseDN()},
},
{
Name: "objectClass",
@ -32,9 +33,9 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
{
Name: "namingContexts",
Values: []string{
pi.BaseDN,
pi.GroupDN,
pi.UserDN,
ds.si.GetBaseDN(),
ds.si.GetBaseUserDN(),
ds.si.GetBaseGroupDN(),
},
},
{

View File

@ -0,0 +1,219 @@
package direct
import (
"errors"
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
)
type DirectSearcher struct {
si server.LDAPServerInstance
log *log.Entry
}
func NewDirectSearcher(si server.LDAPServerInstance) *DirectSearcher {
ds := &DirectSearcher{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.searcher.direct"),
}
ds.log.Info("initialised direct searcher")
return ds
}
func (ds *DirectSearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
u, _, err := ds.si.GetAPIClient().CoreApi.CoreUsersRetrieve(req.Context(), f.UserPk).Execute()
if err != nil {
req.Log().WithError(err).Warning("Failed to get user info")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
f.UserInfo = &u
}
entries := make([]*ldap.Entry, 1)
entries[0] = ds.si.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + ds.si.GetBaseDN())
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ds.si.GetBaseDN())
}
flags, ok := ds.si.GetFlags(req.BindDN)
if !ok {
req.Log().Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.Scope == ldap.ScopeBaseObject {
req.Log().Debug("base scope, showing domain info")
return ds.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
req.Log().Debug("User can't search, showing info about user")
return ds.SearchMe(req, flags)
}
accsp.Finish()
parsedFilter, err := ldap.CompileFilter(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
// Create a custom client to set additional headers
c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case constants.OCGroupOfUniqueNames:
fallthrough
case constants.OCAKGroup:
fallthrough
case constants.OCAKVirtualGroup:
fallthrough
case constants.OCGroup:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
gapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_group")
searchReq, skip := utils.ParseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return
}
groups, _, err := searchReq.Execute()
gapisp.Finish()
if err != nil {
req.Log().WithError(err).Warning("failed to get groups")
return
}
req.Log().WithField("count", len(groups.Results)).Trace("Got results from API")
for _, g := range groups.Results {
gEntries = append(gEntries, group.FromAPIGroup(g, ds.si).Entry())
}
}()
go func() {
defer wg.Done()
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
req.Log().WithError(err).Warning("failed to get users")
return
}
for _, u := range users.Results {
uEntries = append(uEntries, group.FromAPIUser(u, ds.si).Entry())
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case constants.OCOrgPerson:
fallthrough
case constants.OCInetOrgPerson:
fallthrough
case constants.OCAKUser:
fallthrough
case constants.OCUser:
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
}
for _, u := range users.Results {
entries = append(entries, ds.si.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}

View File

@ -0,0 +1,54 @@
package memory
import (
"fmt"
"github.com/nmcclain/ldap"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ldap/search"
)
func (ms *MemorySearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
dn := ""
if authz {
dn = req.SearchRequest.BaseDN
}
return ldap.ServerSearchResult{
Entries: []*ldap.Entry{
{
DN: dn,
Attributes: []*ldap.EntryAttribute{
{
Name: "distinguishedName",
Values: []string{ms.si.GetBaseDN()},
},
{
Name: "objectClass",
Values: []string{"top", "domain"},
},
{
Name: "supportedLDAPVersion",
Values: []string{"3"},
},
{
Name: "namingContexts",
Values: []string{
ms.si.GetBaseDN(),
ms.si.GetBaseUserDN(),
ms.si.GetBaseGroupDN(),
},
},
{
Name: "vendorName",
Values: []string{"goauthentik.io"},
},
{
Name: "vendorVersion",
Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())},
},
},
},
},
Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess,
}, nil
}

View File

@ -0,0 +1,63 @@
package memory
import (
"context"
"goauthentik.io/api"
)
const pageSize = 100
func (ms *MemorySearcher) FetchUsers() []api.User {
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update users")
return nil, err
}
ms.log.WithField("page", page).Debug("fetched users")
return &users, nil
}
page := 1
users := make([]api.User, 0)
for {
apiUsers, err := fetchUsersOffset(page)
if err != nil {
return users
}
if apiUsers.Pagination.Next > 0 {
page += 1
} else {
break
}
users = append(users, apiUsers.Results...)
}
return users
}
func (ms *MemorySearcher) FetchGroups() []api.Group {
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update groups")
return nil, err
}
ms.log.WithField("page", page).Debug("fetched groups")
return &groups, nil
}
page := 1
groups := make([]api.Group, 0)
for {
apiGroups, err := fetchGroupsOffset(page)
if err != nil {
return groups
}
if apiGroups.Pagination.Next > 0 {
page += 1
} else {
break
}
groups = append(groups, apiGroups.Results...)
}
return groups
}

View File

@ -0,0 +1,182 @@
package memory
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server"
)
type MemorySearcher struct {
si server.LDAPServerInstance
log *log.Entry
users []api.User
groups []api.Group
}
func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
ms := &MemorySearcher{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
}
ms.log.Info("initialised memory searcher")
ms.users = ms.FetchUsers()
ms.groups = ms.FetchGroups()
return ms
}
func (ms *MemorySearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
for _, u := range ms.users {
if u.Pk == f.UserPk {
f.UserInfo = &u
}
}
if f.UserInfo == nil {
req.Log().WithField("pk", f.UserPk).Warning("User with pk is not in local cache")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
}
entries := make([]*ldap.Entry, 1)
entries[0] = ms.si.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + ms.si.GetBaseDN())
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN())
}
flags, ok := ms.si.GetFlags(req.BindDN)
if !ok {
req.Log().Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.Scope == ldap.ScopeBaseObject {
req.Log().Debug("base scope, showing domain info")
return ms.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
req.Log().Debug("User can't search, showing info about user")
return ms.SearchMe(req, flags)
}
accsp.Finish()
// parsedFilter, err := ldap.CompileFilter(req.Filter)
// if err != nil {
// metrics.RequestsRejected.With(prometheus.Labels{
// "outpost_name": ms.si.GetOutpostName(),
// "type": "search",
// "reason": "filter_parse_fail",
// "dn": req.BindDN,
// "client": req.RemoteAddr(),
// }).Inc()
// return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
// }
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case constants.OCGroupOfUniqueNames:
fallthrough
case constants.OCAKGroup:
fallthrough
case constants.OCAKVirtualGroup:
fallthrough
case constants.OCGroup:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
for _, g := range ms.groups {
gEntries = append(gEntries, group.FromAPIGroup(g, ms.si).Entry())
}
}()
go func() {
defer wg.Done()
for _, u := range ms.users {
uEntries = append(uEntries, group.FromAPIUser(u, ms.si).Entry())
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case constants.OCOrgPerson:
fallthrough
case constants.OCInetOrgPerson:
fallthrough
case constants.OCAKUser:
fallthrough
case constants.OCUser:
for _, u := range ms.users {
entries = append(entries, ms.si.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}

View File

@ -0,0 +1,53 @@
package search
import (
"context"
"net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/utils"
)
type Request struct {
ldap.SearchRequest
BindDN string
log *log.Entry
id string
conn net.Conn
ctx context.Context
}
func NewRequest(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (*Request, *sentry.Span) {
rid := uuid.New().String()
bindDN = strings.ToLower(bindDN)
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
span.SetTag("ak_filter", searchReq.Filter)
span.SetTag("ak_base_dn", searchReq.BaseDN)
return &Request{
SearchRequest: searchReq,
BindDN: bindDN,
conn: conn,
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
id: rid,
ctx: span.Context(),
}, span
}
func (r *Request) Context() context.Context {
return r.ctx
}
func (r *Request) Log() *log.Entry {
return r.log
}
func (r *Request) RemoteAddr() string {
return utils.GetIP(r.conn.RemoteAddr())
}

View File

@ -0,0 +1,7 @@
package search
import "github.com/nmcclain/ldap"
type Searcher interface {
Search(req *Request) (ldap.ServerSearchResult, error)
}

View File

@ -0,0 +1,35 @@
package server
import (
"github.com/go-openapi/strfmt"
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/flags"
)
type LDAPServerInstance interface {
GetAPIClient() *api.APIClient
GetOutpostName() string
GetFlowSlug() string
GetAppSlug() string
GetSearchAllowedGroups() []*strfmt.UUID
UserEntry(u api.User) *ldap.Entry
GetBaseDN() string
GetBaseGroupDN() string
GetBaseUserDN() string
GetUserDN(string) string
GetGroupDN(string) string
GetVirtualGroupDN(string) string
GetUidNumber(api.User) string
GetGidNumber(api.Group) string
UsersForGroup(api.Group) []string
GetFlags(string) (flags.UserFlags, bool)
SetFlags(string, flags.UserFlags)
}

View File

@ -3,70 +3,12 @@ package ldap
import (
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
)
func BoolToString(in bool) string {
if in {
return "true"
}
return "false"
}
func ldapResolveTypeSingle(in interface{}) *string {
switch t := in.(type) {
case string:
return &t
case *string:
return t
case bool:
s := BoolToString(t)
return &s
case *bool:
s := BoolToString(*t)
return &s
default:
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
return nil
}
}
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
attrList := []*ldap.EntryAttribute{}
if attrs == nil {
return attrList
}
a := attrs.(*map[string]interface{})
for attrKey, attrValue := range *a {
entry := &ldap.EntryAttribute{Name: attrKey}
switch t := attrValue.(type) {
case []string:
entry.Values = t
case *[]string:
entry.Values = *t
case []interface{}:
entry.Values = make([]string, len(t))
for idx, v := range t {
v := ldapResolveTypeSingle(v)
entry.Values[idx] = *v
}
default:
v := ldapResolveTypeSingle(t)
if v != nil {
entry.Values = []string{*v}
}
}
attrList = append(attrList, entry)
}
return attrList
}
func (pi *ProviderInstance) GroupsForUser(user api.User) []string {
groups := make([]string, len(user.Groups))
for i, group := range user.GroupsObj {
@ -83,32 +25,6 @@ func (pi *ProviderInstance) UsersForGroup(group api.Group) []string {
return users
}
func (pi *ProviderInstance) APIGroupToLDAPGroup(g api.Group) LDAPGroup {
return LDAPGroup{
dn: pi.GetGroupDN(g.Name),
cn: g.Name,
uid: string(g.Pk),
gidNumber: pi.GetGidNumber(g),
member: pi.UsersForGroup(g),
isVirtualGroup: false,
isSuperuser: *g.IsSuperuser,
akAttributes: g.Attributes,
}
}
func (pi *ProviderInstance) APIUserToLDAPGroup(u api.User) LDAPGroup {
return LDAPGroup{
dn: pi.GetVirtualGroupDN(u.Username),
cn: u.Username,
uid: u.Uid,
gidNumber: pi.GetUidNumber(u),
member: []string{pi.GetUserDN(u.Username)},
isVirtualGroup: true,
isSuperuser: false,
akAttributes: nil,
}
}
func (pi *ProviderInstance) GetUserDN(user string) string {
return fmt.Sprintf("cn=%s,%s", user, pi.UserDN)
}
@ -155,26 +71,3 @@ func (pi *ProviderInstance) GetRIDForGroup(uid string) int32 {
return int32(gid)
}
func (pi *ProviderInstance) ensureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
for name, values := range shouldHave {
attrs = pi.mustHaveAttribute(attrs, name, values)
}
return attrs
}
func (pi *ProviderInstance) mustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
shouldSet := true
for _, attr := range attrs {
if attr.Name == name {
shouldSet = false
}
}
if shouldSet {
return append(attrs, &ldap.EntryAttribute{
Name: name,
Values: value,
})
}
return attrs
}

View File

@ -0,0 +1,86 @@
package utils
import (
"reflect"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
)
func BoolToString(in bool) string {
if in {
return "true"
}
return "false"
}
func ldapResolveTypeSingle(in interface{}) *string {
switch t := in.(type) {
case string:
return &t
case *string:
return t
case bool:
s := BoolToString(t)
return &s
case *bool:
s := BoolToString(*t)
return &s
default:
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
return nil
}
}
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
attrList := []*ldap.EntryAttribute{}
if attrs == nil {
return attrList
}
a := attrs.(*map[string]interface{})
for attrKey, attrValue := range *a {
entry := &ldap.EntryAttribute{Name: attrKey}
switch t := attrValue.(type) {
case []string:
entry.Values = t
case *[]string:
entry.Values = *t
case []interface{}:
entry.Values = make([]string, len(t))
for idx, v := range t {
v := ldapResolveTypeSingle(v)
entry.Values[idx] = *v
}
default:
v := ldapResolveTypeSingle(t)
if v != nil {
entry.Values = []string{*v}
}
}
attrList = append(attrList, entry)
}
return attrList
}
func EnsureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
for name, values := range shouldHave {
attrs = MustHaveAttribute(attrs, name, values)
}
return attrs
}
func MustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
shouldSet := true
for _, attr := range attrs {
if attr.Name == name {
shouldSet = false
}
}
if shouldSet {
return append(attrs, &ldap.EntryAttribute{
Name: name,
Values: value,
})
}
return attrs
}

View File

@ -1,19 +1,20 @@
package ldap
package utils
import (
goldap "github.com/go-ldap/ldap/v3"
ber "github.com/nmcclain/asn1-ber"
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
)
func parseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) {
func ParseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) {
switch f.Tag {
case ldap.FilterEqualityMatch:
return parseFilterForGroupSingle(req, f)
case ldap.FilterAnd:
for _, child := range f.Children {
r, s := parseFilterForGroup(req, child, skip)
r, s := ParseFilterForGroup(req, child, skip)
skip = skip || s
req = r
}
@ -53,7 +54,7 @@ func parseFilterForGroupSingle(req api.ApiCoreGroupsListRequest, f *ber.Packet)
username := userDN.RDNs[0].Attributes[0].Value
// If the DN's first ou is virtual-groups, ignore this filter
if len(userDN.RDNs) > 1 {
if userDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU || userDN.RDNs[1].Attributes[0].Value == GroupsOU {
if userDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups || userDN.RDNs[1].Attributes[0].Value == constants.OUGroups {
// Since we know we're not filtering anything, skip this request
return req, true
}

View File

@ -1,19 +1,20 @@
package ldap
package utils
import (
goldap "github.com/go-ldap/ldap/v3"
ber "github.com/nmcclain/asn1-ber"
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
)
func parseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) {
func ParseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) {
switch f.Tag {
case ldap.FilterEqualityMatch:
return parseFilterForUserSingle(req, f)
case ldap.FilterAnd:
for _, child := range f.Children {
r, s := parseFilterForUser(req, child, skip)
r, s := ParseFilterForUser(req, child, skip)
skip = skip || s
req = r
}
@ -58,7 +59,7 @@ func parseFilterForUserSingle(req api.ApiCoreUsersListRequest, f *ber.Packet) (a
name := groupDN.RDNs[0].Attributes[0].Value
// If the DN's first ou is virtual-groups, ignore this filter
if len(groupDN.RDNs) > 1 {
if groupDN.RDNs[1].Attributes[0].Value == UsersOU || groupDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU {
if groupDN.RDNs[1].Attributes[0].Value == constants.OUUsers || groupDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups {
// Since we know we're not filtering anything, skip this request
return req, true
}