outposts: implement general paginator for list API requests (#10619)
* outposts: implement general paginator Signed-off-by: Jens Langhammer <jens@goauthentik.io> * migrate LDAP Signed-off-by: Jens Langhammer <jens@goauthentik.io> * change main outpost refresh logic to use paginator everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add comments to understand anything Signed-off-by: Jens Langhammer <jens@goauthentik.io> * actually use paginator everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
64
internal/outpost/ak/api_utils.go
Normal file
64
internal/outpost/ak/api_utils.go
Normal file
@ -0,0 +1,64 @@
|
||||
package ak
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
)
|
||||
|
||||
// Generic interface that mimics a generated request by the API client
|
||||
// Requires mainly `Treq` which will be the actual request type, and
|
||||
// `Tres` which is the response type
|
||||
type PaginatorRequest[Treq any, Tres any] interface {
|
||||
Page(page int32) Treq
|
||||
PageSize(size int32) Treq
|
||||
Execute() (Tres, *http.Response, error)
|
||||
}
|
||||
|
||||
// Generic interface that mimics a generated response by the API client
|
||||
type PaginatorResponse[Tobj any] interface {
|
||||
GetResults() []Tobj
|
||||
GetPagination() api.Pagination
|
||||
}
|
||||
|
||||
// Paginator options for page size
|
||||
type PaginatorOptions struct {
|
||||
PageSize int
|
||||
Logger *log.Entry
|
||||
}
|
||||
|
||||
// Automatically fetch all objects from an API endpoint using the pagination
|
||||
// data received from the server.
|
||||
func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
|
||||
req PaginatorRequest[Treq, Tres],
|
||||
opts PaginatorOptions,
|
||||
) ([]Tobj, error) {
|
||||
fetchOffset := func(page int32) (Tres, error) {
|
||||
req.Page(page)
|
||||
req.PageSize(int32(opts.PageSize))
|
||||
res, _, err := req.Execute()
|
||||
if err != nil {
|
||||
opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
var page int32 = 1
|
||||
errs := make([]error, 0)
|
||||
objects := make([]Tobj, 0)
|
||||
for {
|
||||
apiObjects, err := fetchOffset(page)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
objects = append(objects, apiObjects.GetResults()...)
|
||||
if apiObjects.GetPagination().Next > 0 {
|
||||
page += 1
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return objects, errors.Join(errs...)
|
||||
}
|
@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/bind"
|
||||
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
|
||||
memorybind "goauthentik.io/internal/outpost/ldap/bind/memory"
|
||||
@ -40,16 +41,19 @@ func (ls *LDAPServer) getInvalidationFlow() string {
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Refresh() error {
|
||||
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
|
||||
apiProviders, err := ak.Paginator(ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ls.log,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(outposts.Results) < 1 {
|
||||
if len(apiProviders) < 1 {
|
||||
return errors.New("no ldap provider defined")
|
||||
}
|
||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||
providers := make([]*ProviderInstance, len(apiProviders))
|
||||
invalidationFlow := ls.getInvalidationFlow()
|
||||
for idx, provider := range outposts.Results {
|
||||
for idx, provider := range apiProviders {
|
||||
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
||||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
||||
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
|
||||
|
@ -12,13 +12,13 @@ import (
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/group"
|
||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||
"goauthentik.io/internal/outpost/ldap/search"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
||||
)
|
||||
|
||||
type DirectSearcher struct {
|
||||
@ -120,9 +120,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
return nil
|
||||
}
|
||||
|
||||
u := paginator.FetchUsers(searchReq)
|
||||
u, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ds.log,
|
||||
})
|
||||
uapisp.Finish()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
users = &u
|
||||
} else {
|
||||
if flags.UserInfo == nil {
|
||||
@ -161,8 +166,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
|
||||
}
|
||||
|
||||
g := paginator.FetchGroups(searchReq)
|
||||
g, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ds.log,
|
||||
})
|
||||
gapisp.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Log().WithField("count", len(g)).Trace("Got results from API")
|
||||
|
||||
if !flags.CanSearch {
|
||||
@ -177,7 +188,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups = &g
|
||||
return nil
|
||||
})
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
"goauthentik.io/internal/outpost/ldap/group"
|
||||
@ -19,7 +20,6 @@ import (
|
||||
"goauthentik.io/internal/outpost/ldap/search/direct"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
||||
)
|
||||
|
||||
type MemorySearcher struct {
|
||||
@ -38,8 +38,17 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
|
||||
ds: direct.NewDirectSearcher(si),
|
||||
}
|
||||
ms.log.Debug("initialised memory searcher")
|
||||
ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true))
|
||||
ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true))
|
||||
// Error is not handled here, we get an empty/truncated list and the error is logged
|
||||
users, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ms.log,
|
||||
})
|
||||
ms.users = users
|
||||
groups, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ms.log,
|
||||
})
|
||||
ms.groups = groups
|
||||
return ms
|
||||
}
|
||||
|
||||
|
@ -1,64 +0,0 @@
|
||||
package paginator
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
)
|
||||
|
||||
const PageSize = 100
|
||||
|
||||
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
|
||||
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
|
||||
users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to update users")
|
||||
return nil, err
|
||||
}
|
||||
log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
|
||||
return users, nil
|
||||
}
|
||||
page := 1
|
||||
users := make([]api.User, 0)
|
||||
for {
|
||||
apiUsers, err := fetchUsersOffset(page)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
|
||||
continue
|
||||
}
|
||||
users = append(users, apiUsers.Results...)
|
||||
if apiUsers.Pagination.Next > 0 {
|
||||
page += 1
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
|
||||
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
|
||||
groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("failed to update groups")
|
||||
return nil, err
|
||||
}
|
||||
log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
|
||||
return groups, nil
|
||||
}
|
||||
page := 1
|
||||
groups := make([]api.Group, 0)
|
||||
for {
|
||||
apiGroups, err := fetchGroupsOffset(page)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
|
||||
continue
|
||||
}
|
||||
groups = append(groups, apiGroups.Results...)
|
||||
if apiGroups.Pagination.Next > 0 {
|
||||
page += 1
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
@ -14,7 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func (ps *ProxyServer) Refresh() error {
|
||||
providers, _, err := ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()).Execute()
|
||||
providers, err := ak.Paginator(ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: ps.log,
|
||||
})
|
||||
if err != nil {
|
||||
ps.log.WithError(err).Error("Failed to fetch providers")
|
||||
}
|
||||
@ -22,7 +25,7 @@ func (ps *ProxyServer) Refresh() error {
|
||||
return err
|
||||
}
|
||||
apps := make(map[string]*application.Application)
|
||||
for _, provider := range providers.Results {
|
||||
for _, provider := range providers {
|
||||
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
||||
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
|
||||
hc := &http.Client{
|
||||
@ -35,7 +38,7 @@ func (ps *ProxyServer) Refresh() error {
|
||||
),
|
||||
}
|
||||
a, err := application.NewApplication(provider, hc, ps)
|
||||
existing, ok := apps[a.Host]
|
||||
existing, ok := ps.apps[a.Host]
|
||||
if ok {
|
||||
existing.Stop()
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
)
|
||||
|
||||
func parseCIDRs(raw string) []*net.IPNet {
|
||||
@ -30,15 +31,18 @@ func parseCIDRs(raw string) []*net.IPNet {
|
||||
}
|
||||
|
||||
func (rs *RadiusServer) Refresh() error {
|
||||
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
|
||||
apiProviders, err := ak.Paginator(rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: rs.log,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(outposts.Results) < 1 {
|
||||
if len(apiProviders) < 1 {
|
||||
return errors.New("no radius provider defined")
|
||||
}
|
||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
||||
for idx, provider := range outposts.Results {
|
||||
providers := make([]*ProviderInstance, len(apiProviders))
|
||||
for idx, provider := range apiProviders {
|
||||
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
||||
providers[idx] = &ProviderInstance{
|
||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||
|
@ -1,6 +1,7 @@
|
||||
package brand_tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"strings"
|
||||
"time"
|
||||
@ -46,20 +47,25 @@ func (w *Watcher) Start() {
|
||||
|
||||
func (w *Watcher) Check() {
|
||||
w.log.Info("updating brand certificates")
|
||||
brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{})
|
||||
brands, err := ak.Paginator(w.client.CoreApi.CoreBrandsList(context.Background()), ak.PaginatorOptions{
|
||||
PageSize: 100,
|
||||
Logger: w.log,
|
||||
})
|
||||
if err != nil {
|
||||
w.log.WithError(err).Warning("failed to get brands")
|
||||
return
|
||||
}
|
||||
for _, t := range brands.Results {
|
||||
if kp := t.WebCertificate.Get(); kp != nil {
|
||||
err := w.cs.AddKeypair(*kp)
|
||||
if err != nil {
|
||||
w.log.WithError(err).Warning("failed to add certificate")
|
||||
}
|
||||
for _, b := range brands {
|
||||
kp := b.WebCertificate.Get()
|
||||
if kp == nil {
|
||||
continue
|
||||
}
|
||||
err := w.cs.AddKeypair(*kp)
|
||||
if err != nil {
|
||||
w.log.WithError(err).Warning("failed to add certificate")
|
||||
}
|
||||
}
|
||||
w.brands = brands.Results
|
||||
w.brands = brands
|
||||
}
|
||||
|
||||
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
|
Reference in New Issue
Block a user