Merge branch 'main' into celery-2-dramatiq

This commit is contained in:
Marc 'risson' Schmitt
2025-05-26 18:29:26 +02:00
682 changed files with 40181 additions and 20687 deletions

View File

@ -3,6 +3,7 @@ package brand_tls
import (
"context"
"crypto/tls"
"crypto/x509"
"strings"
"time"
@ -56,22 +57,37 @@ func (w *Watcher) Check() {
return
}
for _, b := range brands {
kp := b.WebCertificate.Get()
if kp == nil {
continue
kp := b.GetWebCertificate()
if kp != "" {
err := w.cs.AddKeypair(kp)
if err != nil {
w.log.WithError(err).WithField("kp", kp).Warning("failed to add web certificate")
}
}
err := w.cs.AddKeypair(*kp)
if err != nil {
w.log.WithError(err).Warning("failed to add certificate")
for _, crt := range b.GetClientCertificates() {
if crt != "" {
err := w.cs.AddKeypair(crt)
if err != nil {
w.log.WithError(err).WithField("kp", kp).Warning("failed to add client certificate")
}
}
}
}
w.brands = brands
}
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
type CertificateConfig struct {
Web *tls.Certificate
Client *x509.CertPool
}
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) *CertificateConfig {
var bestSelection *api.Brand
config := CertificateConfig{
Web: w.fallback,
}
for _, t := range w.brands {
if t.WebCertificate.Get() == nil {
if !t.WebCertificate.IsSet() && len(t.GetClientCertificates()) < 1 {
continue
}
if *t.Default {
@ -82,11 +98,20 @@ func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, err
}
}
if bestSelection == nil {
return w.fallback, nil
return &config
}
cert := w.cs.Get(bestSelection.GetWebCertificate())
if cert == nil {
return w.fallback, nil
if bestSelection.GetWebCertificate() != "" {
if cert := w.cs.Get(bestSelection.GetWebCertificate()); cert != nil {
config.Web = cert
}
}
return cert, nil
if len(bestSelection.GetClientCertificates()) > 0 {
config.Client = x509.NewCertPool()
for _, kp := range bestSelection.GetClientCertificates() {
if cert := w.cs.Get(kp); cert != nil {
config.Client.AddCert(cert.Leaf)
}
}
}
return &config
}

View File

@ -1,15 +1,11 @@
package web
import (
"encoding/base64"
"fmt"
"io"
"net/http"
"os"
"path"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
@ -18,8 +14,6 @@ import (
"goauthentik.io/internal/utils/sentry"
)
const MetricsKeyFile = "authentik-core-metrics.key"
var Requests = promauto.NewHistogramVec(prometheus.HistogramOpts{
Name: "authentik_main_request_duration_seconds",
Help: "API request latencies in seconds",
@ -27,14 +21,6 @@ var Requests = promauto.NewHistogramVec(prometheus.HistogramOpts{
func (ws *WebServer) runMetricsServer() {
l := log.WithField("logger", "authentik.router.metrics")
tmp := os.TempDir()
key := base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(64))
keyPath := path.Join(tmp, MetricsKeyFile)
err := os.WriteFile(keyPath, []byte(key), 0o600)
if err != nil {
l.WithError(err).Warning("failed to save metrics key")
return
}
m := mux.NewRouter()
m.Use(sentry.SentryNoSampleMiddleware)
@ -51,7 +37,7 @@ func (ws *WebServer) runMetricsServer() {
l.WithError(err).Warning("failed to get upstream metrics")
return
}
re.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
re.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ws.metricsKey))
res, err := ws.upstreamHttpClient().Do(re)
if err != nil {
l.WithError(err).Warning("failed to get upstream metrics")
@ -64,13 +50,9 @@ func (ws *WebServer) runMetricsServer() {
}
})
l.WithField("listen", config.Get().Listen.Metrics).Info("Starting Metrics server")
err = http.ListenAndServe(config.Get().Listen.Metrics, m)
err := http.ListenAndServe(config.Get().Listen.Metrics, m)
if err != nil {
l.WithError(err).Warning("Failed to start metrics server")
}
l.WithField("listen", config.Get().Listen.Metrics).Info("Stopping Metrics server")
err = os.Remove(keyPath)
if err != nil {
l.WithError(err).Warning("failed to remove metrics key file")
}
}

View File

@ -2,21 +2,29 @@ package web
import (
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/internal/config"
"goauthentik.io/internal/utils/sentry"
"goauthentik.io/internal/utils/web"
)
var (
ErrAuthentikStarting = errors.New("authentik starting")
)
const (
maxBodyBytes = 32 * 1024 * 1024
)
func (ws *WebServer) configureProxy() {
// Reverse proxy to the application server
director := func(req *http.Request) {
@ -26,8 +34,25 @@ func (ws *WebServer) configureProxy() {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
if !web.IsRequestFromTrustedProxy(req) {
// If the request isn't coming from a trusted proxy, delete MTLS headers
req.Header.Del("SSL-Client-Cert") // nginx-ingress
req.Header.Del("X-Forwarded-TLS-Client-Cert") // traefik
req.Header.Del("X-Forwarded-Client-Cert") // envoy
}
if req.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
if len(req.TLS.PeerCertificates) > 0 {
pems := make([]string, len(req.TLS.PeerCertificates))
for i, crt := range req.TLS.PeerCertificates {
pem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: crt.Raw,
})
pems[i] = "Cert=" + url.QueryEscape(string(pem))
}
req.Header.Set("X-Forwarded-Client-Cert", strings.Join(pems, ","))
}
}
ws.log.WithField("url", req.URL.String()).WithField("headers", req.Header).Trace("tracing request to backend")
}
@ -57,7 +82,7 @@ func (ws *WebServer) configureProxy() {
Requests.With(prometheus.Labels{
"dest": "core",
}).Observe(float64(elapsed) / float64(time.Second))
r.Body = http.MaxBytesReader(rw, r.Body, 32*1024*1024)
r.Body = http.MaxBytesReader(rw, r.Body, maxBodyBytes)
rp.ServeHTTP(rw, r)
}))
}

View File

@ -2,6 +2,7 @@ package web
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
@ -13,11 +14,15 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/gorilla/securecookie"
"github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"goauthentik.io/internal/config"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/gounicorn"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/proxyv2"
"goauthentik.io/internal/utils"
"goauthentik.io/internal/utils/web"
@ -25,6 +30,12 @@ import (
"goauthentik.io/internal/worker"
)
const (
IPCKeyFile = "authentik-core-ipc.key"
MetricsKeyFile = "authentik-core-metrics.key"
UnixSocketName = "authentik-core.sock"
)
type WebServer struct {
Bind string
BindTLS bool
@ -42,9 +53,10 @@ type WebServer struct {
log *log.Entry
upstreamClient *http.Client
upstreamURL *url.URL
}
const UnixSocketName = "authentik-core.sock"
metricsKey string
ipcKey string
}
func NewWebServer() *WebServer {
l := log.WithField("logger", "authentik.router")
@ -78,7 +90,7 @@ func NewWebServer() *WebServer {
mainRouter: mainHandler,
loggingRouter: loggingHandler,
log: l,
gunicornReady: true,
gunicornReady: false,
upstreamClient: upstreamClient,
upstreamURL: u,
}
@ -105,7 +117,59 @@ func NewWebServer() *WebServer {
return ws
}
func (ws *WebServer) prepareKeys() {
tmp := os.TempDir()
key := base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(64))
err := os.WriteFile(path.Join(tmp, MetricsKeyFile), []byte(key), 0o600)
if err != nil {
ws.log.WithError(err).Warning("failed to save metrics key")
return
}
ws.metricsKey = key
key = base64.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(64))
err = os.WriteFile(path.Join(tmp, IPCKeyFile), []byte(key), 0o600)
if err != nil {
ws.log.WithError(err).Warning("failed to save ipc key")
return
}
ws.ipcKey = key
}
func (ws *WebServer) Start() {
ws.prepareKeys()
u, err := url.Parse(fmt.Sprintf("http://%s%s", config.Get().Listen.HTTP, config.Get().Web.Path))
if err != nil {
panic(err)
}
apiConfig := api.NewConfiguration()
apiConfig.Host = u.Host
apiConfig.Scheme = u.Scheme
apiConfig.HTTPClient = &http.Client{
Transport: web.NewUserAgentTransport(
constants.UserAgentIPC(),
ak.GetTLSTransport(),
),
}
apiConfig.Servers = api.ServerConfigurations{
{
URL: fmt.Sprintf("%sapi/v3", u.Path),
},
}
apiConfig.AddDefaultHeader("Authorization", fmt.Sprintf("Bearer %s", ws.ipcKey))
// create the API client, with the transport
apiClient := api.NewAPIClient(apiConfig)
// Init brand_tls here too since it requires an API Client,
// so we just reuse the same one as the outpost uses
tw := brand_tls.NewWatcher(apiClient)
ws.BrandTLS = tw
ws.g.AddHealthyCallback(func() {
go tw.Start()
})
go ws.runMetricsServer()
go ws.attemptStartBackend()
go ws.attemptStartWorker()
@ -115,23 +179,23 @@ func (ws *WebServer) Start() {
func (ws *WebServer) attemptStartBackend() {
for {
if !ws.gunicornReady {
if ws.gunicornReady {
return
}
err := ws.g.Start()
log.WithField("logger", "authentik.router").WithError(err).Warning("gunicorn process died, restarting")
ws.log.WithError(err).Warning("gunicorn process died, restarting")
if err != nil {
log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn failed to start, restarting")
ws.log.WithError(err).Error("gunicorn failed to start, restarting")
continue
}
failedChecks := 0
for range time.NewTicker(30 * time.Second).C {
if !ws.g.IsRunning() {
log.WithField("logger", "authentik.router").Warningf("gunicorn process failed healthcheck %d times", failedChecks)
ws.log.Warningf("gunicorn process failed healthcheck %d times", failedChecks)
failedChecks += 1
}
if failedChecks >= 3 {
log.WithField("logger", "authentik.router").WithError(err).Error("gunicorn process failed healthcheck three times, restarting")
ws.log.WithError(err).Error("gunicorn process failed healthcheck three times, restarting")
break
}
}
@ -155,6 +219,15 @@ func (ws *WebServer) upstreamHttpClient() *http.Client {
func (ws *WebServer) Shutdown() {
ws.log.Info("shutting down gunicorn")
ws.g.Kill()
tmp := os.TempDir()
err := os.Remove(path.Join(tmp, MetricsKeyFile))
if err != nil {
ws.log.WithError(err).Warning("failed to remove metrics key file")
}
err = os.Remove(path.Join(tmp, IPCKeyFile))
if err != nil {
ws.log.WithError(err).Warning("failed to remove ipc key file")
}
ws.stop <- struct{}{}
}

View File

@ -12,40 +12,57 @@ import (
"goauthentik.io/internal/utils/web"
)
func (ws *WebServer) GetCertificate() func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := crypto.GenerateSelfSignedCert()
func (ws *WebServer) GetCertificate() func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
fallback, err := crypto.GenerateSelfSignedCert()
if err != nil {
ws.log.WithError(err).Error("failed to generate default cert")
}
return func(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
cfg := utils.GetTLSConfig()
if ch.ServerName == "" {
return &cert, nil
cfg.Certificates = []tls.Certificate{fallback}
return cfg, nil
}
if ws.ProxyServer != nil {
appCert := ws.ProxyServer.GetCertificate(ch.ServerName)
if appCert != nil {
return appCert, nil
cfg.Certificates = []tls.Certificate{*appCert}
return cfg, nil
}
}
if ws.BrandTLS != nil {
return ws.BrandTLS.GetCertificate(ch)
bcert := ws.BrandTLS.GetCertificate(ch)
cfg.Certificates = []tls.Certificate{*bcert.Web}
ws.log.Trace("using brand web Certificate")
if bcert.Client != nil {
cfg.ClientCAs = bcert.Client
cfg.ClientAuth = tls.RequestClientCert
ws.log.Trace("using brand client Certificate")
}
return cfg, nil
}
ws.log.Trace("using default, self-signed certificate")
return &cert, nil
cfg.Certificates = []tls.Certificate{fallback}
return cfg, nil
}
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (ws *WebServer) listenTLS() {
tlsConfig := utils.GetTLSConfig()
tlsConfig.GetCertificate = ws.GetCertificate()
tlsConfig.GetConfigForClient = ws.GetCertificate()
ln, err := net.Listen("tcp", config.Get().Listen.HTTPS)
if err != nil {
ws.log.WithError(err).Warning("failed to listen (TLS)")
return
}
proxyListener := &proxyproto.Listener{Listener: web.TCPKeepAliveListener{TCPListener: ln.(*net.TCPListener)}, ConnPolicy: utils.GetProxyConnectionPolicy()}
proxyListener := &proxyproto.Listener{
Listener: web.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
},
ConnPolicy: utils.GetProxyConnectionPolicy(),
}
defer func() {
err := proxyListener.Close()
if err != nil {