outposts: implement general paginator for list API requests (#10619)

* outposts: implement general paginator

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* migrate LDAP

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* change main outpost refresh logic to use paginator everywhere

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add comments to understand anything

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* actually use paginator everywhere

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-07-29 22:14:18 +02:00
committed by GitHub
parent d79ac0e5bc
commit 1b285f85c0
8 changed files with 127 additions and 91 deletions

View File

@ -0,0 +1,64 @@
package ak
import (
"errors"
"net/http"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
)
// Generic interface that mimics a generated request by the API client
// Requires mainly `Treq` which will be the actual request type, and
// `Tres` which is the response type
type PaginatorRequest[Treq any, Tres any] interface {
Page(page int32) Treq
PageSize(size int32) Treq
Execute() (Tres, *http.Response, error)
}
// Generic interface that mimics a generated response by the API client
type PaginatorResponse[Tobj any] interface {
GetResults() []Tobj
GetPagination() api.Pagination
}
// Paginator options for page size
type PaginatorOptions struct {
PageSize int
Logger *log.Entry
}
// Automatically fetch all objects from an API endpoint using the pagination
// data received from the server.
func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
req PaginatorRequest[Treq, Tres],
opts PaginatorOptions,
) ([]Tobj, error) {
fetchOffset := func(page int32) (Tres, error) {
req.Page(page)
req.PageSize(int32(opts.PageSize))
res, _, err := req.Execute()
if err != nil {
opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
}
return res, err
}
var page int32 = 1
errs := make([]error, 0)
objects := make([]Tobj, 0)
for {
apiObjects, err := fetchOffset(page)
if err != nil {
errs = append(errs, err)
continue
}
objects = append(objects, apiObjects.GetResults()...)
if apiObjects.GetPagination().Next > 0 {
page += 1
} else {
break
}
}
return objects, errors.Join(errs...)
}

View File

@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/bind"
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
memorybind "goauthentik.io/internal/outpost/ldap/bind/memory"
@ -40,16 +41,19 @@ func (ls *LDAPServer) getInvalidationFlow() string {
}
func (ls *LDAPServer) Refresh() error {
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
apiProviders, err := ak.Paginator(ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()), ak.PaginatorOptions{
PageSize: 100,
Logger: ls.log,
})
if err != nil {
return err
}
if len(outposts.Results) < 1 {
if len(apiProviders) < 1 {
return errors.New("no ldap provider defined")
}
providers := make([]*ProviderInstance, len(outposts.Results))
providers := make([]*ProviderInstance, len(apiProviders))
invalidationFlow := ls.getInvalidationFlow()
for idx, provider := range outposts.Results {
for idx, provider := range apiProviders {
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))

View File

@ -12,13 +12,13 @@ import (
"github.com/getsentry/sentry-go"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
"goauthentik.io/internal/outpost/ldap/utils/paginator"
)
type DirectSearcher struct {
@ -120,9 +120,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
return nil
}
u := paginator.FetchUsers(searchReq)
u, err := ak.Paginator(searchReq, ak.PaginatorOptions{
PageSize: 100,
Logger: ds.log,
})
uapisp.Finish()
if err != nil {
return err
}
users = &u
} else {
if flags.UserInfo == nil {
@ -161,8 +166,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
}
g := paginator.FetchGroups(searchReq)
g, err := ak.Paginator(searchReq, ak.PaginatorOptions{
PageSize: 100,
Logger: ds.log,
})
gapisp.Finish()
if err != nil {
return err
}
req.Log().WithField("count", len(g)).Trace("Got results from API")
if !flags.CanSearch {
@ -177,7 +188,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
}
}
}
groups = &g
return nil
})

View File

@ -11,6 +11,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group"
@ -19,7 +20,6 @@ import (
"goauthentik.io/internal/outpost/ldap/search/direct"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
"goauthentik.io/internal/outpost/ldap/utils/paginator"
)
type MemorySearcher struct {
@ -38,8 +38,17 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
ds: direct.NewDirectSearcher(si),
}
ms.log.Debug("initialised memory searcher")
ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true))
ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true))
// Error is not handled here, we get an empty/truncated list and the error is logged
users, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{
PageSize: 100,
Logger: ms.log,
})
ms.users = users
groups, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true), ak.PaginatorOptions{
PageSize: 100,
Logger: ms.log,
})
ms.groups = groups
return ms
}

View File

@ -1,64 +0,0 @@
package paginator
import (
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
)
const PageSize = 100
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
if err != nil {
log.WithError(err).Warning("failed to update users")
return nil, err
}
log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
return users, nil
}
page := 1
users := make([]api.User, 0)
for {
apiUsers, err := fetchUsersOffset(page)
if err != nil {
log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
continue
}
users = append(users, apiUsers.Results...)
if apiUsers.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return users
}
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
if err != nil {
log.WithError(err).Warning("failed to update groups")
return nil, err
}
log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
return groups, nil
}
page := 1
groups := make([]api.Group, 0)
for {
apiGroups, err := fetchGroupsOffset(page)
if err != nil {
log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
continue
}
groups = append(groups, apiGroups.Results...)
if apiGroups.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return groups
}

View File

@ -14,7 +14,10 @@ import (
)
func (ps *ProxyServer) Refresh() error {
providers, _, err := ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()).Execute()
providers, err := ak.Paginator(ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()), ak.PaginatorOptions{
PageSize: 100,
Logger: ps.log,
})
if err != nil {
ps.log.WithError(err).Error("Failed to fetch providers")
}
@ -22,7 +25,7 @@ func (ps *ProxyServer) Refresh() error {
return err
}
apps := make(map[string]*application.Application)
for _, provider := range providers.Results {
for _, provider := range providers {
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
hc := &http.Client{
@ -35,7 +38,7 @@ func (ps *ProxyServer) Refresh() error {
),
}
a, err := application.NewApplication(provider, hc, ps)
existing, ok := apps[a.Host]
existing, ok := ps.apps[a.Host]
if ok {
existing.Stop()
}

View File

@ -8,6 +8,7 @@ import (
"strings"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ak"
)
func parseCIDRs(raw string) []*net.IPNet {
@ -30,15 +31,18 @@ func parseCIDRs(raw string) []*net.IPNet {
}
func (rs *RadiusServer) Refresh() error {
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
apiProviders, err := ak.Paginator(rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()), ak.PaginatorOptions{
PageSize: 100,
Logger: rs.log,
})
if err != nil {
return err
}
if len(outposts.Results) < 1 {
if len(apiProviders) < 1 {
return errors.New("no radius provider defined")
}
providers := make([]*ProviderInstance, len(outposts.Results))
for idx, provider := range outposts.Results {
providers := make([]*ProviderInstance, len(apiProviders))
for idx, provider := range apiProviders {
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
providers[idx] = &ProviderInstance{
SharedSecret: []byte(provider.GetSharedSecret()),

View File

@ -1,6 +1,7 @@
package brand_tls
import (
"context"
"crypto/tls"
"strings"
"time"
@ -46,20 +47,25 @@ func (w *Watcher) Start() {
func (w *Watcher) Check() {
w.log.Info("updating brand certificates")
brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{})
brands, err := ak.Paginator(w.client.CoreApi.CoreBrandsList(context.Background()), ak.PaginatorOptions{
PageSize: 100,
Logger: w.log,
})
if err != nil {
w.log.WithError(err).Warning("failed to get brands")
return
}
for _, t := range brands.Results {
if kp := t.WebCertificate.Get(); kp != nil {
err := w.cs.AddKeypair(*kp)
if err != nil {
w.log.WithError(err).Warning("failed to add certificate")
}
for _, b := range brands {
kp := b.WebCertificate.Get()
if kp == nil {
continue
}
err := w.cs.AddKeypair(*kp)
if err != nil {
w.log.WithError(err).Warning("failed to add certificate")
}
}
w.brands = brands.Results
w.brands = brands
}
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {