From 57f25a97c956136cf9695dfe9cadd8774a0f4675 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Wed, 28 May 2025 13:43:35 +0200 Subject: [PATCH] providers/ldap: retain binder and update users instead of re-creating (#14735) Signed-off-by: Jens Langhammer --- internal/outpost/ldap/bind/memory/memory.go | 20 ++++++++++--------- internal/outpost/ldap/refresh.go | 7 ++++++- internal/outpost/ldap/search/memory/memory.go | 16 +++++++++++++-- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/internal/outpost/ldap/bind/memory/memory.go b/internal/outpost/ldap/bind/memory/memory.go index ef5eb56a85..97080cf582 100644 --- a/internal/outpost/ldap/bind/memory/memory.go +++ b/internal/outpost/ldap/bind/memory/memory.go @@ -28,16 +28,18 @@ func NewSessionBinder(si server.LDAPServerInstance, oldBinder bind.Binder) *Sess si: si, log: log.WithField("logger", "authentik.outpost.ldap.binder.session"), } - if oldSb, ok := oldBinder.(*SessionBinder); ok { - sb.DirectBinder = oldSb.DirectBinder - sb.sessions = oldSb.sessions - sb.log.Debug("re-initialised session binder") - } else { - sb.sessions = ttlcache.New(ttlcache.WithDisableTouchOnHit[Credentials, ldap.LDAPResultCode]()) - sb.DirectBinder = *direct.NewDirectBinder(si) - go sb.sessions.Start() - sb.log.Debug("initialised session binder") + if oldBinder != nil { + if oldSb, ok := oldBinder.(*SessionBinder); ok { + sb.DirectBinder = oldSb.DirectBinder + sb.sessions = oldSb.sessions + sb.log.Debug("re-initialised session binder") + return sb + } } + sb.sessions = ttlcache.New(ttlcache.WithDisableTouchOnHit[Credentials, ldap.LDAPResultCode]()) + sb.DirectBinder = *direct.NewDirectBinder(si) + go sb.sessions.Start() + sb.log.Debug("initialised session binder") return sb } diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 0f00bbeb26..2ad73ab661 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -16,6 +16,7 @@ import ( memorybind "goauthentik.io/internal/outpost/ldap/bind/memory" "goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/flags" + "goauthentik.io/internal/outpost/ldap/search" directsearch "goauthentik.io/internal/outpost/ldap/search/direct" memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" ) @@ -85,7 +86,11 @@ func (ls *LDAPServer) Refresh() error { providers[idx].certUUID = *kp } if *provider.SearchMode.Ptr() == api.LDAPAPIACCESSMODE_CACHED { - providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx]) + var oldSearcher search.Searcher + if existing != nil { + oldSearcher = existing.searcher + } + providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx], oldSearcher) } else if *provider.SearchMode.Ptr() == api.LDAPAPIACCESSMODE_DIRECT { providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx]) } diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go index c4f23a60e8..509b98e824 100644 --- a/internal/outpost/ldap/search/memory/memory.go +++ b/internal/outpost/ldap/search/memory/memory.go @@ -31,13 +31,26 @@ type MemorySearcher struct { groups []api.Group } -func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher { +func NewMemorySearcher(si server.LDAPServerInstance, existing search.Searcher) *MemorySearcher { ms := &MemorySearcher{ si: si, log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"), ds: direct.NewDirectSearcher(si), } + if existing != nil { + if ems, ok := existing.(*MemorySearcher); ok { + ems.si = si + ems.fetch() + ems.log.Debug("re-initialised memory searcher") + return ems + } + } + ms.fetch() ms.log.Debug("initialised memory searcher") + return ms +} + +func (ms *MemorySearcher) fetch() { // Error is not handled here, we get an empty/truncated list and the error is logged users, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{ PageSize: 100, @@ -49,7 +62,6 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher { Logger: ms.log, }) ms.groups = groups - return ms } func (ms *MemorySearcher) SearchBase(req *search.Request) (ldap.ServerSearchResult, error) {