From 1b285f85c0666bc734618ee3dd889e7786b09850 Mon Sep 17 00:00:00 2001 From: "Jens L." Date: Mon, 29 Jul 2024 22:14:18 +0200 Subject: [PATCH] outposts: implement general paginator for list API requests (#10619) * outposts: implement general paginator Signed-off-by: Jens Langhammer * migrate LDAP Signed-off-by: Jens Langhammer * change main outpost refresh logic to use paginator everywhere Signed-off-by: Jens Langhammer * add comments to understand anything Signed-off-by: Jens Langhammer * actually use paginator everywhere Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer --- internal/outpost/ak/api_utils.go | 64 +++++++++++++++++++ internal/outpost/ldap/refresh.go | 12 ++-- internal/outpost/ldap/search/direct/direct.go | 20 ++++-- internal/outpost/ldap/search/memory/memory.go | 15 ++++- .../outpost/ldap/utils/paginator/paginator.go | 64 ------------------- internal/outpost/proxyv2/refresh.go | 9 ++- internal/outpost/radius/api.go | 12 ++-- internal/web/brand_tls/brand_tls.go | 22 ++++--- 8 files changed, 127 insertions(+), 91 deletions(-) create mode 100644 internal/outpost/ak/api_utils.go delete mode 100644 internal/outpost/ldap/utils/paginator/paginator.go diff --git a/internal/outpost/ak/api_utils.go b/internal/outpost/ak/api_utils.go new file mode 100644 index 0000000000..6a8a90280b --- /dev/null +++ b/internal/outpost/ak/api_utils.go @@ -0,0 +1,64 @@ +package ak + +import ( + "errors" + "net/http" + + log "github.com/sirupsen/logrus" + "goauthentik.io/api/v3" +) + +// Generic interface that mimics a generated request by the API client +// Requires mainly `Treq` which will be the actual request type, and +// `Tres` which is the response type +type PaginatorRequest[Treq any, Tres any] interface { + Page(page int32) Treq + PageSize(size int32) Treq + Execute() (Tres, *http.Response, error) +} + +// Generic interface that mimics a generated response by the API client +type PaginatorResponse[Tobj any] interface { + GetResults() []Tobj + GetPagination() api.Pagination +} + +// Paginator options for page size +type PaginatorOptions struct { + PageSize int + Logger *log.Entry +} + +// Automatically fetch all objects from an API endpoint using the pagination +// data received from the server. +func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]]( + req PaginatorRequest[Treq, Tres], + opts PaginatorOptions, +) ([]Tobj, error) { + fetchOffset := func(page int32) (Tres, error) { + req.Page(page) + req.PageSize(int32(opts.PageSize)) + res, _, err := req.Execute() + if err != nil { + opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page") + } + return res, err + } + var page int32 = 1 + errs := make([]error, 0) + objects := make([]Tobj, 0) + for { + apiObjects, err := fetchOffset(page) + if err != nil { + errs = append(errs, err) + continue + } + objects = append(objects, apiObjects.GetResults()...) + if apiObjects.GetPagination().Next > 0 { + page += 1 + } else { + break + } + } + return objects, errors.Join(errs...) +} diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index b58849bfe0..9f5dbc1496 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" + "goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ldap/bind" directbind "goauthentik.io/internal/outpost/ldap/bind/direct" memorybind "goauthentik.io/internal/outpost/ldap/bind/memory" @@ -40,16 +41,19 @@ func (ls *LDAPServer) getInvalidationFlow() string { } func (ls *LDAPServer) Refresh() error { - outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() + apiProviders, err := ak.Paginator(ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()), ak.PaginatorOptions{ + PageSize: 100, + Logger: ls.log, + }) if err != nil { return err } - if len(outposts.Results) < 1 { + if len(apiProviders) < 1 { return errors.New("no ldap provider defined") } - providers := make([]*ProviderInstance, len(outposts.Results)) + providers := make([]*ProviderInstance, len(apiProviders)) invalidationFlow := ls.getInvalidationFlow() - for idx, provider := range outposts.Results { + for idx, provider := range apiProviders { userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn)) groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) diff --git a/internal/outpost/ldap/search/direct/direct.go b/internal/outpost/ldap/search/direct/direct.go index 1122bd611b..fdc52e4111 100644 --- a/internal/outpost/ldap/search/direct/direct.go +++ b/internal/outpost/ldap/search/direct/direct.go @@ -12,13 +12,13 @@ import ( "github.com/getsentry/sentry-go" "github.com/prometheus/client_golang/prometheus" "goauthentik.io/api/v3" + "goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/group" "goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/utils" - "goauthentik.io/internal/outpost/ldap/utils/paginator" ) type DirectSearcher struct { @@ -120,9 +120,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, return nil } - u := paginator.FetchUsers(searchReq) + u, err := ak.Paginator(searchReq, ak.PaginatorOptions{ + PageSize: 100, + Logger: ds.log, + }) uapisp.Finish() - + if err != nil { + return err + } users = &u } else { if flags.UserInfo == nil { @@ -161,8 +166,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, searchReq = searchReq.MembersByPk([]int32{flags.UserPk}) } - g := paginator.FetchGroups(searchReq) + g, err := ak.Paginator(searchReq, ak.PaginatorOptions{ + PageSize: 100, + Logger: ds.log, + }) gapisp.Finish() + if err != nil { + return err + } req.Log().WithField("count", len(g)).Trace("Got results from API") if !flags.CanSearch { @@ -177,7 +188,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, } } } - groups = &g return nil }) diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go index b889edba04..0236cd9f28 100644 --- a/internal/outpost/ldap/search/memory/memory.go +++ b/internal/outpost/ldap/search/memory/memory.go @@ -11,6 +11,7 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" + "goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/flags" "goauthentik.io/internal/outpost/ldap/group" @@ -19,7 +20,6 @@ import ( "goauthentik.io/internal/outpost/ldap/search/direct" "goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/utils" - "goauthentik.io/internal/outpost/ldap/utils/paginator" ) type MemorySearcher struct { @@ -38,8 +38,17 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher { ds: direct.NewDirectSearcher(si), } ms.log.Debug("initialised memory searcher") - ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true)) - ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true)) + // Error is not handled here, we get an empty/truncated list and the error is logged + users, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{ + PageSize: 100, + Logger: ms.log, + }) + ms.users = users + groups, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true), ak.PaginatorOptions{ + PageSize: 100, + Logger: ms.log, + }) + ms.groups = groups return ms } diff --git a/internal/outpost/ldap/utils/paginator/paginator.go b/internal/outpost/ldap/utils/paginator/paginator.go deleted file mode 100644 index f6793b7e51..0000000000 --- a/internal/outpost/ldap/utils/paginator/paginator.go +++ /dev/null @@ -1,64 +0,0 @@ -package paginator - -import ( - log "github.com/sirupsen/logrus" - "goauthentik.io/api/v3" -) - -const PageSize = 100 - -func FetchUsers(req api.ApiCoreUsersListRequest) []api.User { - fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) { - users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute() - if err != nil { - log.WithError(err).Warning("failed to update users") - return nil, err - } - log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users") - return users, nil - } - page := 1 - users := make([]api.User, 0) - for { - apiUsers, err := fetchUsersOffset(page) - if err != nil { - log.WithError(err).WithField("page", page).Warn("Failed to fetch user page") - continue - } - users = append(users, apiUsers.Results...) - if apiUsers.Pagination.Next > 0 { - page += 1 - } else { - break - } - } - return users -} - -func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group { - fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) { - groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute() - if err != nil { - log.WithError(err).Warning("failed to update groups") - return nil, err - } - log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups") - return groups, nil - } - page := 1 - groups := make([]api.Group, 0) - for { - apiGroups, err := fetchGroupsOffset(page) - if err != nil { - log.WithError(err).WithField("page", page).Warn("Failed to fetch group page") - continue - } - groups = append(groups, apiGroups.Results...) - if apiGroups.Pagination.Next > 0 { - page += 1 - } else { - break - } - } - return groups -} diff --git a/internal/outpost/proxyv2/refresh.go b/internal/outpost/proxyv2/refresh.go index 357044c463..b1d1a71164 100644 --- a/internal/outpost/proxyv2/refresh.go +++ b/internal/outpost/proxyv2/refresh.go @@ -14,7 +14,10 @@ import ( ) func (ps *ProxyServer) Refresh() error { - providers, _, err := ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()).Execute() + providers, err := ak.Paginator(ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()), ak.PaginatorOptions{ + PageSize: 100, + Logger: ps.log, + }) if err != nil { ps.log.WithError(err).Error("Failed to fetch providers") } @@ -22,7 +25,7 @@ func (ps *ProxyServer) Refresh() error { return err } apps := make(map[string]*application.Application) - for _, provider := range providers.Results { + for _, provider := range providers { rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss") ua := fmt.Sprintf(" (provider=%s)", provider.Name) hc := &http.Client{ @@ -35,7 +38,7 @@ func (ps *ProxyServer) Refresh() error { ), } a, err := application.NewApplication(provider, hc, ps) - existing, ok := apps[a.Host] + existing, ok := ps.apps[a.Host] if ok { existing.Stop() } diff --git a/internal/outpost/radius/api.go b/internal/outpost/radius/api.go index ab6f275abc..947fb7bf94 100644 --- a/internal/outpost/radius/api.go +++ b/internal/outpost/radius/api.go @@ -8,6 +8,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/ak" ) func parseCIDRs(raw string) []*net.IPNet { @@ -30,15 +31,18 @@ func parseCIDRs(raw string) []*net.IPNet { } func (rs *RadiusServer) Refresh() error { - outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute() + apiProviders, err := ak.Paginator(rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()), ak.PaginatorOptions{ + PageSize: 100, + Logger: rs.log, + }) if err != nil { return err } - if len(outposts.Results) < 1 { + if len(apiProviders) < 1 { return errors.New("no radius provider defined") } - providers := make([]*ProviderInstance, len(outposts.Results)) - for idx, provider := range outposts.Results { + providers := make([]*ProviderInstance, len(apiProviders)) + for idx, provider := range apiProviders { logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name) providers[idx] = &ProviderInstance{ SharedSecret: []byte(provider.GetSharedSecret()), diff --git a/internal/web/brand_tls/brand_tls.go b/internal/web/brand_tls/brand_tls.go index 7f19e33eb6..107cf904c5 100644 --- a/internal/web/brand_tls/brand_tls.go +++ b/internal/web/brand_tls/brand_tls.go @@ -1,6 +1,7 @@ package brand_tls import ( + "context" "crypto/tls" "strings" "time" @@ -46,20 +47,25 @@ func (w *Watcher) Start() { func (w *Watcher) Check() { w.log.Info("updating brand certificates") - brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{}) + brands, err := ak.Paginator(w.client.CoreApi.CoreBrandsList(context.Background()), ak.PaginatorOptions{ + PageSize: 100, + Logger: w.log, + }) if err != nil { w.log.WithError(err).Warning("failed to get brands") return } - for _, t := range brands.Results { - if kp := t.WebCertificate.Get(); kp != nil { - err := w.cs.AddKeypair(*kp) - if err != nil { - w.log.WithError(err).Warning("failed to add certificate") - } + for _, b := range brands { + kp := b.WebCertificate.Get() + if kp == nil { + continue + } + err := w.cs.AddKeypair(*kp) + if err != nil { + w.log.WithError(err).Warning("failed to add certificate") } } - w.brands = brands.Results + w.brands = brands } func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {