outposts: implement general paginator for list API requests (#10619)
* outposts: implement general paginator Signed-off-by: Jens Langhammer <jens@goauthentik.io> * migrate LDAP Signed-off-by: Jens Langhammer <jens@goauthentik.io> * change main outpost refresh logic to use paginator everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add comments to understand anything Signed-off-by: Jens Langhammer <jens@goauthentik.io> * actually use paginator everywhere Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
64
internal/outpost/ak/api_utils.go
Normal file
64
internal/outpost/ak/api_utils.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package ak
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/api/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generic interface that mimics a generated request by the API client
|
||||||
|
// Requires mainly `Treq` which will be the actual request type, and
|
||||||
|
// `Tres` which is the response type
|
||||||
|
type PaginatorRequest[Treq any, Tres any] interface {
|
||||||
|
Page(page int32) Treq
|
||||||
|
PageSize(size int32) Treq
|
||||||
|
Execute() (Tres, *http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic interface that mimics a generated response by the API client
|
||||||
|
type PaginatorResponse[Tobj any] interface {
|
||||||
|
GetResults() []Tobj
|
||||||
|
GetPagination() api.Pagination
|
||||||
|
}
|
||||||
|
|
||||||
|
// Paginator options for page size
|
||||||
|
type PaginatorOptions struct {
|
||||||
|
PageSize int
|
||||||
|
Logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Automatically fetch all objects from an API endpoint using the pagination
|
||||||
|
// data received from the server.
|
||||||
|
func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
|
||||||
|
req PaginatorRequest[Treq, Tres],
|
||||||
|
opts PaginatorOptions,
|
||||||
|
) ([]Tobj, error) {
|
||||||
|
fetchOffset := func(page int32) (Tres, error) {
|
||||||
|
req.Page(page)
|
||||||
|
req.PageSize(int32(opts.PageSize))
|
||||||
|
res, _, err := req.Execute()
|
||||||
|
if err != nil {
|
||||||
|
opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
|
||||||
|
}
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
var page int32 = 1
|
||||||
|
errs := make([]error, 0)
|
||||||
|
objects := make([]Tobj, 0)
|
||||||
|
for {
|
||||||
|
apiObjects, err := fetchOffset(page)
|
||||||
|
if err != nil {
|
||||||
|
errs = append(errs, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
objects = append(objects, apiObjects.GetResults()...)
|
||||||
|
if apiObjects.GetPagination().Next > 0 {
|
||||||
|
page += 1
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return objects, errors.Join(errs...)
|
||||||
|
}
|
||||||
@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
|
"goauthentik.io/internal/outpost/ak"
|
||||||
"goauthentik.io/internal/outpost/ldap/bind"
|
"goauthentik.io/internal/outpost/ldap/bind"
|
||||||
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
|
directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
|
||||||
memorybind "goauthentik.io/internal/outpost/ldap/bind/memory"
|
memorybind "goauthentik.io/internal/outpost/ldap/bind/memory"
|
||||||
@ -40,16 +41,19 @@ func (ls *LDAPServer) getInvalidationFlow() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ls *LDAPServer) Refresh() error {
|
func (ls *LDAPServer) Refresh() error {
|
||||||
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
|
apiProviders, err := ak.Paginator(ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ls.log,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(outposts.Results) < 1 {
|
if len(apiProviders) < 1 {
|
||||||
return errors.New("no ldap provider defined")
|
return errors.New("no ldap provider defined")
|
||||||
}
|
}
|
||||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
providers := make([]*ProviderInstance, len(apiProviders))
|
||||||
invalidationFlow := ls.getInvalidationFlow()
|
invalidationFlow := ls.getInvalidationFlow()
|
||||||
for idx, provider := range outposts.Results {
|
for idx, provider := range apiProviders {
|
||||||
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
|
||||||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
||||||
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
|
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
|
||||||
|
|||||||
@ -12,13 +12,13 @@ import (
|
|||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
|
"goauthentik.io/internal/outpost/ak"
|
||||||
"goauthentik.io/internal/outpost/ldap/constants"
|
"goauthentik.io/internal/outpost/ldap/constants"
|
||||||
"goauthentik.io/internal/outpost/ldap/group"
|
"goauthentik.io/internal/outpost/ldap/group"
|
||||||
"goauthentik.io/internal/outpost/ldap/metrics"
|
"goauthentik.io/internal/outpost/ldap/metrics"
|
||||||
"goauthentik.io/internal/outpost/ldap/search"
|
"goauthentik.io/internal/outpost/ldap/search"
|
||||||
"goauthentik.io/internal/outpost/ldap/server"
|
"goauthentik.io/internal/outpost/ldap/server"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils"
|
"goauthentik.io/internal/outpost/ldap/utils"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type DirectSearcher struct {
|
type DirectSearcher struct {
|
||||||
@ -120,9 +120,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
u := paginator.FetchUsers(searchReq)
|
u, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ds.log,
|
||||||
|
})
|
||||||
uapisp.Finish()
|
uapisp.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
users = &u
|
users = &u
|
||||||
} else {
|
} else {
|
||||||
if flags.UserInfo == nil {
|
if flags.UserInfo == nil {
|
||||||
@ -161,8 +166,14 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
|||||||
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
|
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
|
||||||
}
|
}
|
||||||
|
|
||||||
g := paginator.FetchGroups(searchReq)
|
g, err := ak.Paginator(searchReq, ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ds.log,
|
||||||
|
})
|
||||||
gapisp.Finish()
|
gapisp.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
req.Log().WithField("count", len(g)).Trace("Got results from API")
|
req.Log().WithField("count", len(g)).Trace("Got results from API")
|
||||||
|
|
||||||
if !flags.CanSearch {
|
if !flags.CanSearch {
|
||||||
@ -177,7 +188,6 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groups = &g
|
groups = &g
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"goauthentik.io/api/v3"
|
"goauthentik.io/api/v3"
|
||||||
|
"goauthentik.io/internal/outpost/ak"
|
||||||
"goauthentik.io/internal/outpost/ldap/constants"
|
"goauthentik.io/internal/outpost/ldap/constants"
|
||||||
"goauthentik.io/internal/outpost/ldap/flags"
|
"goauthentik.io/internal/outpost/ldap/flags"
|
||||||
"goauthentik.io/internal/outpost/ldap/group"
|
"goauthentik.io/internal/outpost/ldap/group"
|
||||||
@ -19,7 +20,6 @@ import (
|
|||||||
"goauthentik.io/internal/outpost/ldap/search/direct"
|
"goauthentik.io/internal/outpost/ldap/search/direct"
|
||||||
"goauthentik.io/internal/outpost/ldap/server"
|
"goauthentik.io/internal/outpost/ldap/server"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils"
|
"goauthentik.io/internal/outpost/ldap/utils"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type MemorySearcher struct {
|
type MemorySearcher struct {
|
||||||
@ -38,8 +38,17 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
|
|||||||
ds: direct.NewDirectSearcher(si),
|
ds: direct.NewDirectSearcher(si),
|
||||||
}
|
}
|
||||||
ms.log.Debug("initialised memory searcher")
|
ms.log.Debug("initialised memory searcher")
|
||||||
ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true))
|
// Error is not handled here, we get an empty/truncated list and the error is logged
|
||||||
ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true))
|
users, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).IncludeGroups(true), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ms.log,
|
||||||
|
})
|
||||||
|
ms.users = users
|
||||||
|
groups, _ := ak.Paginator(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).IncludeUsers(true), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ms.log,
|
||||||
|
})
|
||||||
|
ms.groups = groups
|
||||||
return ms
|
return ms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,64 +0,0 @@
|
|||||||
package paginator
|
|
||||||
|
|
||||||
import (
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"goauthentik.io/api/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const PageSize = 100
|
|
||||||
|
|
||||||
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
|
|
||||||
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
|
|
||||||
users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warning("failed to update users")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
page := 1
|
|
||||||
users := make([]api.User, 0)
|
|
||||||
for {
|
|
||||||
apiUsers, err := fetchUsersOffset(page)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
users = append(users, apiUsers.Results...)
|
|
||||||
if apiUsers.Pagination.Next > 0 {
|
|
||||||
page += 1
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return users
|
|
||||||
}
|
|
||||||
|
|
||||||
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
|
|
||||||
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
|
|
||||||
groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).Warning("failed to update groups")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
|
|
||||||
return groups, nil
|
|
||||||
}
|
|
||||||
page := 1
|
|
||||||
groups := make([]api.Group, 0)
|
|
||||||
for {
|
|
||||||
apiGroups, err := fetchGroupsOffset(page)
|
|
||||||
if err != nil {
|
|
||||||
log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
groups = append(groups, apiGroups.Results...)
|
|
||||||
if apiGroups.Pagination.Next > 0 {
|
|
||||||
page += 1
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return groups
|
|
||||||
}
|
|
||||||
@ -14,7 +14,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (ps *ProxyServer) Refresh() error {
|
func (ps *ProxyServer) Refresh() error {
|
||||||
providers, _, err := ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()).Execute()
|
providers, err := ak.Paginator(ps.akAPI.Client.OutpostsApi.OutpostsProxyList(context.Background()), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: ps.log,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ps.log.WithError(err).Error("Failed to fetch providers")
|
ps.log.WithError(err).Error("Failed to fetch providers")
|
||||||
}
|
}
|
||||||
@ -22,7 +25,7 @@ func (ps *ProxyServer) Refresh() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
apps := make(map[string]*application.Application)
|
apps := make(map[string]*application.Application)
|
||||||
for _, provider := range providers.Results {
|
for _, provider := range providers {
|
||||||
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
rsp := sentry.StartSpan(context.Background(), "authentik.outposts.proxy.application_ss")
|
||||||
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
|
ua := fmt.Sprintf(" (provider=%s)", provider.Name)
|
||||||
hc := &http.Client{
|
hc := &http.Client{
|
||||||
@ -35,7 +38,7 @@ func (ps *ProxyServer) Refresh() error {
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
a, err := application.NewApplication(provider, hc, ps)
|
a, err := application.NewApplication(provider, hc, ps)
|
||||||
existing, ok := apps[a.Host]
|
existing, ok := ps.apps[a.Host]
|
||||||
if ok {
|
if ok {
|
||||||
existing.Stop()
|
existing.Stop()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/internal/outpost/ak"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseCIDRs(raw string) []*net.IPNet {
|
func parseCIDRs(raw string) []*net.IPNet {
|
||||||
@ -30,15 +31,18 @@ func parseCIDRs(raw string) []*net.IPNet {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rs *RadiusServer) Refresh() error {
|
func (rs *RadiusServer) Refresh() error {
|
||||||
outposts, _, err := rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()).Execute()
|
apiProviders, err := ak.Paginator(rs.ac.Client.OutpostsApi.OutpostsRadiusList(context.Background()), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: rs.log,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(outposts.Results) < 1 {
|
if len(apiProviders) < 1 {
|
||||||
return errors.New("no radius provider defined")
|
return errors.New("no radius provider defined")
|
||||||
}
|
}
|
||||||
providers := make([]*ProviderInstance, len(outposts.Results))
|
providers := make([]*ProviderInstance, len(apiProviders))
|
||||||
for idx, provider := range outposts.Results {
|
for idx, provider := range apiProviders {
|
||||||
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
logger := log.WithField("logger", "authentik.outpost.radius").WithField("provider", provider.Name)
|
||||||
providers[idx] = &ProviderInstance{
|
providers[idx] = &ProviderInstance{
|
||||||
SharedSecret: []byte(provider.GetSharedSecret()),
|
SharedSecret: []byte(provider.GetSharedSecret()),
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package brand_tls
|
package brand_tls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -46,20 +47,25 @@ func (w *Watcher) Start() {
|
|||||||
|
|
||||||
func (w *Watcher) Check() {
|
func (w *Watcher) Check() {
|
||||||
w.log.Info("updating brand certificates")
|
w.log.Info("updating brand certificates")
|
||||||
brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{})
|
brands, err := ak.Paginator(w.client.CoreApi.CoreBrandsList(context.Background()), ak.PaginatorOptions{
|
||||||
|
PageSize: 100,
|
||||||
|
Logger: w.log,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.WithError(err).Warning("failed to get brands")
|
w.log.WithError(err).Warning("failed to get brands")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, t := range brands.Results {
|
for _, b := range brands {
|
||||||
if kp := t.WebCertificate.Get(); kp != nil {
|
kp := b.WebCertificate.Get()
|
||||||
err := w.cs.AddKeypair(*kp)
|
if kp == nil {
|
||||||
if err != nil {
|
continue
|
||||||
w.log.WithError(err).Warning("failed to add certificate")
|
}
|
||||||
}
|
err := w.cs.AddKeypair(*kp)
|
||||||
|
if err != nil {
|
||||||
|
w.log.WithError(err).Warning("failed to add certificate")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.brands = brands.Results
|
w.brands = brands
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user