providers/ldap: correctly use pagination in search results in both modes (#5492)
closes #4292 Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		@ -18,6 +18,7 @@ import (
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/search"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/server"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/utils"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/utils/paginator"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DirectSearcher struct {
 | 
			
		||||
@ -124,15 +125,10 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
 | 
			
		||||
					return nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				u, _, err := searchReq.Execute()
 | 
			
		||||
				u := paginator.FetchUsers(searchReq)
 | 
			
		||||
				uapisp.Finish()
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					req.Log().WithError(err).Warning("failed to get users")
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				users = &u.Results
 | 
			
		||||
				users = &u
 | 
			
		||||
			} else {
 | 
			
		||||
				if flags.UserInfo == nil {
 | 
			
		||||
					uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
 | 
			
		||||
@ -170,29 +166,24 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
 | 
			
		||||
				searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			g, _, err := searchReq.Execute()
 | 
			
		||||
			g := paginator.FetchGroups(searchReq)
 | 
			
		||||
			gapisp.Finish()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				req.Log().WithError(err).Warning("failed to get groups")
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			req.Log().WithField("count", len(g.Results)).Trace("Got results from API")
 | 
			
		||||
			req.Log().WithField("count", len(g)).Trace("Got results from API")
 | 
			
		||||
 | 
			
		||||
			if !flags.CanSearch {
 | 
			
		||||
				for i, results := range g.Results {
 | 
			
		||||
				for i, results := range g {
 | 
			
		||||
					// If they can't search, remove any users from the group results except the one we're looking for.
 | 
			
		||||
					g.Results[i].Users = []int32{flags.UserPk}
 | 
			
		||||
					g[i].Users = []int32{flags.UserPk}
 | 
			
		||||
					for _, u := range results.UsersObj {
 | 
			
		||||
						if u.Pk == flags.UserPk {
 | 
			
		||||
							g.Results[i].UsersObj = []api.GroupMember{u}
 | 
			
		||||
							g[i].UsersObj = []api.GroupMember{u}
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			groups = &g.Results
 | 
			
		||||
 | 
			
		||||
			groups = &g
 | 
			
		||||
			return nil
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -1,63 +0,0 @@
 | 
			
		||||
package memory
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"goauthentik.io/api/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const pageSize = 100
 | 
			
		||||
 | 
			
		||||
func (ms *MemorySearcher) FetchUsers() []api.User {
 | 
			
		||||
	fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
 | 
			
		||||
		users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			ms.log.WithError(err).Warning("failed to update users")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		ms.log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
 | 
			
		||||
		return users, nil
 | 
			
		||||
	}
 | 
			
		||||
	page := 1
 | 
			
		||||
	users := make([]api.User, 0)
 | 
			
		||||
	for {
 | 
			
		||||
		apiUsers, err := fetchUsersOffset(page)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return users
 | 
			
		||||
		}
 | 
			
		||||
		users = append(users, apiUsers.Results...)
 | 
			
		||||
		if apiUsers.Pagination.Next > 0 {
 | 
			
		||||
			page += 1
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return users
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ms *MemorySearcher) FetchGroups() []api.Group {
 | 
			
		||||
	fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
 | 
			
		||||
		groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			ms.log.WithError(err).Warning("failed to update groups")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		ms.log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
 | 
			
		||||
		return groups, nil
 | 
			
		||||
	}
 | 
			
		||||
	page := 1
 | 
			
		||||
	groups := make([]api.Group, 0)
 | 
			
		||||
	for {
 | 
			
		||||
		apiGroups, err := fetchGroupsOffset(page)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return groups
 | 
			
		||||
		}
 | 
			
		||||
		groups = append(groups, apiGroups.Results...)
 | 
			
		||||
		if apiGroups.Pagination.Next > 0 {
 | 
			
		||||
			page += 1
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return groups
 | 
			
		||||
}
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package memory
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
@ -16,6 +17,7 @@ import (
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/search"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/server"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/utils"
 | 
			
		||||
	"goauthentik.io/internal/outpost/ldap/utils/paginator"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MemorySearcher struct {
 | 
			
		||||
@ -32,8 +34,8 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
 | 
			
		||||
		log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
 | 
			
		||||
	}
 | 
			
		||||
	ms.log.Debug("initialised memory searcher")
 | 
			
		||||
	ms.users = ms.FetchUsers()
 | 
			
		||||
	ms.groups = ms.FetchGroups()
 | 
			
		||||
	ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()))
 | 
			
		||||
	ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()))
 | 
			
		||||
	return ms
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										64
									
								
								internal/outpost/ldap/utils/paginator/paginator.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								internal/outpost/ldap/utils/paginator/paginator.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,64 @@
 | 
			
		||||
package paginator
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"goauthentik.io/api/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const PageSize = 100
 | 
			
		||||
 | 
			
		||||
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
 | 
			
		||||
	fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
 | 
			
		||||
		users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Warning("failed to update users")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
 | 
			
		||||
		return users, nil
 | 
			
		||||
	}
 | 
			
		||||
	page := 1
 | 
			
		||||
	users := make([]api.User, 0)
 | 
			
		||||
	for {
 | 
			
		||||
		apiUsers, err := fetchUsersOffset(page)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		users = append(users, apiUsers.Results...)
 | 
			
		||||
		if apiUsers.Pagination.Next > 0 {
 | 
			
		||||
			page += 1
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return users
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
 | 
			
		||||
	fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
 | 
			
		||||
		groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Warning("failed to update groups")
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
 | 
			
		||||
		return groups, nil
 | 
			
		||||
	}
 | 
			
		||||
	page := 1
 | 
			
		||||
	groups := make([]api.Group, 0)
 | 
			
		||||
	for {
 | 
			
		||||
		apiGroups, err := fetchGroupsOffset(page)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		groups = append(groups, apiGroups.Results...)
 | 
			
		||||
		if apiGroups.Pagination.Next > 0 {
 | 
			
		||||
			page += 1
 | 
			
		||||
		} else {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return groups
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user