diff --git a/internal/outpost/ak/api_utils.go b/internal/outpost/ak/api_utils.go index d732ba7556..9b477903ae 100644 --- a/internal/outpost/ak/api_utils.go +++ b/internal/outpost/ak/api_utils.go @@ -35,13 +35,19 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]]( req PaginatorRequest[Treq, Tres], opts PaginatorOptions, ) ([]Tobj, error) { + if opts.Logger == nil { + opts.Logger = log.NewEntry(log.StandardLogger()) + } var bfreq, cfreq interface{} fetchOffset := func(page int32) (Tres, error) { bfreq = req.Page(page) cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize)) - res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute() + res, hres, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute() if err != nil { opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page") + if hres != nil && hres.StatusCode >= 400 && hres.StatusCode < 500 { + return res, err + } } return res, err } @@ -51,6 +57,9 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]]( for { apiObjects, err := fetchOffset(page) if err != nil { + if page == 1 { + return objects, err + } errs = append(errs, err) continue } diff --git a/internal/outpost/ak/api_utils_test.go b/internal/outpost/ak/api_utils_test.go index ac751f943a..7ee84e13f4 100644 --- a/internal/outpost/ak/api_utils_test.go +++ b/internal/outpost/ak/api_utils_test.go @@ -1,5 +1,64 @@ package ak +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "goauthentik.io/api/v3" +) + +type fakeAPIType struct{} + +type fakeAPIResponse struct { + results []fakeAPIType + pagination api.Pagination +} + +func (fapi *fakeAPIResponse) GetResults() []fakeAPIType { return fapi.results } +func (fapi *fakeAPIResponse) GetPagination() api.Pagination { return fapi.pagination } + +type fakeAPIRequest struct { + res *fakeAPIResponse + http *http.Response + err error +} + +func (fapi *fakeAPIRequest) Page(page int32) *fakeAPIRequest { return fapi } +func (fapi *fakeAPIRequest) PageSize(size int32) *fakeAPIRequest { return fapi } +func (fapi *fakeAPIRequest) Execute() (*fakeAPIResponse, *http.Response, error) { + return fapi.res, fapi.http, fapi.err +} + +func Test_Simple(t *testing.T) { + req := &fakeAPIRequest{ + res: &fakeAPIResponse{ + results: []fakeAPIType{ + {}, + }, + pagination: api.Pagination{ + TotalPages: 1, + }, + }, + } + res, err := Paginator(req, PaginatorOptions{}) + assert.NoError(t, err) + assert.Len(t, res, 1) +} + +func Test_BadRequest(t *testing.T) { + req := &fakeAPIRequest{ + http: &http.Response{ + StatusCode: 400, + }, + err: errors.New("foo"), + } + res, err := Paginator(req, PaginatorOptions{}) + assert.Error(t, err) + assert.Equal(t, []fakeAPIType{}, res) +} + // func Test_PaginatorCompile(t *testing.T) { // req := api.ApiCoreUsersListRequest{} // Paginator(req, PaginatorOptions{