Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-27 18:54:38 +01:00
133 changed files with 2886 additions and 2218 deletions

View File

@ -162,13 +162,14 @@ func (c *Config) parseScheme(rawVal string) string {
if err != nil {
return rawVal
}
if u.Scheme == "env" {
switch u.Scheme {
case "env":
e, ok := os.LookupEnv(u.Host)
if ok {
return e
}
return u.RawQuery
} else if u.Scheme == "file" {
case "file":
d, err := os.ReadFile(u.Path)
if err != nil {
return u.RawQuery

View File

@ -10,7 +10,7 @@ import (
)
func TestConfigEnv(t *testing.T) {
os.Setenv("AUTHENTIK_SECRET_KEY", "bar")
assert.NoError(t, os.Setenv("AUTHENTIK_SECRET_KEY", "bar"))
cfg = nil
if err := Get().fromEnv(); err != nil {
panic(err)
@ -19,8 +19,8 @@ func TestConfigEnv(t *testing.T) {
}
func TestConfigEnv_Scheme(t *testing.T) {
os.Setenv("foo", "bar")
os.Setenv("AUTHENTIK_SECRET_KEY", "env://foo")
assert.NoError(t, os.Setenv("foo", "bar"))
assert.NoError(t, os.Setenv("AUTHENTIK_SECRET_KEY", "env://foo"))
cfg = nil
if err := Get().fromEnv(); err != nil {
panic(err)
@ -33,13 +33,15 @@ func TestConfigEnv_File(t *testing.T) {
if err != nil {
log.Fatal(err)
}
defer os.Remove(file.Name())
defer func() {
assert.NoError(t, os.Remove(file.Name()))
}()
_, err = file.Write([]byte("bar"))
if err != nil {
panic(err)
}
os.Setenv("AUTHENTIK_SECRET_KEY", fmt.Sprintf("file://%s", file.Name()))
assert.NoError(t, os.Setenv("AUTHENTIK_SECRET_KEY", fmt.Sprintf("file://%s", file.Name())))
cfg = nil
if err := Get().fromEnv(); err != nil {
panic(err)

View File

@ -35,7 +35,7 @@ func EnableDebugServer() {
if err != nil {
return nil
}
_, err = w.Write([]byte(fmt.Sprintf("<a href='%[1]s'>%[1]s</a><br>", tpl)))
_, err = fmt.Fprintf(w, "<a href='%[1]s'>%[1]s</a><br>", tpl)
if err != nil {
l.WithError(err).Warning("failed to write index")
return nil

View File

@ -44,10 +44,11 @@ func New(healthcheck func() bool) *GoUnicorn {
signal.Notify(c, syscall.SIGHUP, syscall.SIGUSR2)
go func() {
for sig := range c {
if sig == syscall.SIGHUP {
switch sig {
case syscall.SIGHUP:
g.log.Info("SIGHUP received, forwarding to gunicorn")
g.Reload()
} else if sig == syscall.SIGUSR2 {
case syscall.SIGUSR2:
g.log.Info("SIGUSR2 received, restarting gunicorn")
g.Restart()
}

View File

@ -35,13 +35,19 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
req PaginatorRequest[Treq, Tres],
opts PaginatorOptions,
) ([]Tobj, error) {
if opts.Logger == nil {
opts.Logger = log.NewEntry(log.StandardLogger())
}
var bfreq, cfreq interface{}
fetchOffset := func(page int32) (Tres, error) {
bfreq = req.Page(page)
cfreq = bfreq.(PaginatorRequest[Treq, Tres]).PageSize(int32(opts.PageSize))
res, _, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute()
res, hres, err := cfreq.(PaginatorRequest[Treq, Tres]).Execute()
if err != nil {
opts.Logger.WithError(err).WithField("page", page).Warning("failed to fetch page")
if hres != nil && hres.StatusCode >= 400 && hres.StatusCode < 500 {
return res, err
}
}
return res, err
}
@ -51,6 +57,9 @@ func Paginator[Tobj any, Treq any, Tres PaginatorResponse[Tobj]](
for {
apiObjects, err := fetchOffset(page)
if err != nil {
if page == 1 {
return objects, err
}
errs = append(errs, err)
continue
}

View File

@ -1,5 +1,64 @@
package ak
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"goauthentik.io/api/v3"
)
type fakeAPIType struct{}
type fakeAPIResponse struct {
results []fakeAPIType
pagination api.Pagination
}
func (fapi *fakeAPIResponse) GetResults() []fakeAPIType { return fapi.results }
func (fapi *fakeAPIResponse) GetPagination() api.Pagination { return fapi.pagination }
type fakeAPIRequest struct {
res *fakeAPIResponse
http *http.Response
err error
}
func (fapi *fakeAPIRequest) Page(page int32) *fakeAPIRequest { return fapi }
func (fapi *fakeAPIRequest) PageSize(size int32) *fakeAPIRequest { return fapi }
func (fapi *fakeAPIRequest) Execute() (*fakeAPIResponse, *http.Response, error) {
return fapi.res, fapi.http, fapi.err
}
func Test_Simple(t *testing.T) {
req := &fakeAPIRequest{
res: &fakeAPIResponse{
results: []fakeAPIType{
{},
},
pagination: api.Pagination{
TotalPages: 1,
},
},
}
res, err := Paginator(req, PaginatorOptions{})
assert.NoError(t, err)
assert.Len(t, res, 1)
}
func Test_BadRequest(t *testing.T) {
req := &fakeAPIRequest{
http: &http.Response{
StatusCode: 400,
},
err: errors.New("foo"),
}
res, err := Paginator(req, PaginatorOptions{})
assert.Error(t, err)
assert.Equal(t, []fakeAPIType{}, res)
}
// func Test_PaginatorCompile(t *testing.T) {
// req := api.ApiCoreUsersListRequest{}
// Paginator(req, PaginatorOptions{

View File

@ -148,7 +148,8 @@ func (ac *APIController) startWSHandler() {
"outpost_type": ac.Server.Type(),
"uuid": ac.instanceUUID.String(),
}).Set(1)
if wsMsg.Instruction == WebsocketInstructionTriggerUpdate {
switch wsMsg.Instruction {
case WebsocketInstructionTriggerUpdate:
time.Sleep(ac.reloadOffset)
logger.Debug("Got update trigger...")
err := ac.OnRefresh()
@ -163,7 +164,7 @@ func (ac *APIController) startWSHandler() {
"build": constants.BUILD(""),
}).SetToCurrentTime()
}
} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific {
case WebsocketInstructionProviderSpecific:
for _, h := range ac.wsHandlers {
h(context.Background(), wsMsg.Args)
}

View File

@ -66,7 +66,12 @@ func (ls *LDAPServer) StartLDAPServer() error {
return err
}
proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ls.log.WithError(err).Warning("failed to close proxy listener")
}
}()
ls.log.WithField("listen", listen).Info("Starting LDAP server")
err = ls.s.Serve(proxyListener)

View File

@ -49,7 +49,12 @@ func (ls *LDAPServer) StartLDAPTLSServer() error {
}
proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ls.log.WithError(err).Warning("failed to close proxy listener")
}
}()
tln := tls.NewListener(proxyListener, tlsConfig)

View File

@ -98,7 +98,7 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
entries := make([]*ldap.Entry, 0)
scope := req.SearchRequest.Scope
scope := req.Scope
needUsers, needGroups := ms.si.GetNeededObjects(scope, req.BaseDN, req.FilterObjectClass)
if scope >= 0 && strings.EqualFold(req.BaseDN, baseDN) {

View File

@ -56,7 +56,7 @@ func GetOIDCEndpoint(p api.ProxyOutpostConfig, authentikHost string, embedded bo
if !embedded && hostBrowser == "" {
return ep
}
var newHost *url.URL = aku
var newHost = aku
var newBrowserHost *url.URL
if embedded {
if authentikHost == "" {

View File

@ -130,7 +130,12 @@ func (ps *ProxyServer) ServeHTTP() {
return
}
proxyListener := &proxyproto.Listener{Listener: listener, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ps.log.WithError(err).Warning("failed to close proxy listener")
}
}()
ps.log.WithField("listen", listenAddress).Info("Starting HTTP server")
ps.serve(proxyListener)
@ -149,7 +154,12 @@ func (ps *ProxyServer) ServeHTTPS() {
return
}
proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ps.log.WithError(err).Warning("failed to close proxy listener")
}
}()
tlsListener := tls.NewListener(proxyListener, tlsConfig)
ps.log.WithField("listen", listenAddress).Info("Starting HTTPS server")

View File

@ -72,11 +72,13 @@ func (s *RedisStore) New(r *http.Request, name string) (*sessions.Session, error
session.ID = c.Value
err = s.load(r.Context(), session)
if err == nil {
session.IsNew = false
} else if err == redis.Nil {
err = nil // no data stored
if err != nil {
if errors.Is(err, redis.Nil) {
return session, nil
}
return session, err
}
session.IsNew = false
return session, err
}

View File

@ -158,7 +158,12 @@ func (ws *WebServer) listenPlain() {
return
}
proxyListener := &proxyproto.Listener{Listener: ln, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ws.log.WithError(err).Warning("failed to close proxy listener")
}
}()
ws.log.WithField("listen", config.Get().Listen.HTTP).Info("Starting HTTP server")
ws.serve(proxyListener)

View File

@ -46,7 +46,12 @@ func (ws *WebServer) listenTLS() {
return
}
proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()}
defer proxyListener.Close()
defer func() {
err := proxyListener.Close()
if err != nil {
ws.log.WithError(err).Warning("failed to close proxy listener")
}
}()
tlsListener := tls.NewListener(proxyListener, tlsConfig)
ws.log.WithField("listen", config.Get().Listen.HTTPS).Info("Starting HTTPS server")