* unrelated changes Signed-off-by: Jens Langhammer <jens@goauthentik.io> * optimization pass 1: reduce N tenant lookups by taking tenant from request, reduce get_anonymous calls Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix lint Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make it easier to exclude anonymous user Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix? Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
		
			
				
	
	
		
			87 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			87 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package brand_tls
 | 
						|
 | 
						|
import (
	"crypto/tls"
	"strings"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"goauthentik.io/api/v3"
	"goauthentik.io/internal/crypto"
	"goauthentik.io/internal/outpost/ak"
)
 | 
						|
 | 
						|
type Watcher struct {
 | 
						|
	client   *api.APIClient
 | 
						|
	log      *log.Entry
 | 
						|
	cs       *ak.CryptoStore
 | 
						|
	fallback *tls.Certificate
 | 
						|
	brands   []api.Brand
 | 
						|
}
 | 
						|
 | 
						|
func NewWatcher(client *api.APIClient) *Watcher {
 | 
						|
	cs := ak.NewCryptoStore(client.CryptoApi)
 | 
						|
	l := log.WithField("logger", "authentik.router.brand_tls")
 | 
						|
	cert, err := crypto.GenerateSelfSignedCert()
 | 
						|
	if err != nil {
 | 
						|
		l.WithError(err).Error("failed to generate default cert")
 | 
						|
	}
 | 
						|
 | 
						|
	return &Watcher{
 | 
						|
		client:   client,
 | 
						|
		log:      l,
 | 
						|
		cs:       cs,
 | 
						|
		fallback: &cert,
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (w *Watcher) Start() {
 | 
						|
	ticker := time.NewTicker(time.Minute * 3)
 | 
						|
	w.log.Info("Starting Brand TLS Checker")
 | 
						|
	for ; true; <-ticker.C {
 | 
						|
		w.Check()
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (w *Watcher) Check() {
 | 
						|
	w.log.Info("updating brand certificates")
 | 
						|
	brands, _, err := w.client.CoreApi.CoreBrandsListExecute(api.ApiCoreBrandsListRequest{})
 | 
						|
	if err != nil {
 | 
						|
		w.log.WithError(err).Warning("failed to get brands")
 | 
						|
		return
 | 
						|
	}
 | 
						|
	for _, t := range brands.Results {
 | 
						|
		if kp := t.WebCertificate.Get(); kp != nil {
 | 
						|
			err := w.cs.AddKeypair(*kp)
 | 
						|
			if err != nil {
 | 
						|
				w.log.WithError(err).Warning("failed to add certificate")
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	w.brands = brands.Results
 | 
						|
}
 | 
						|
 | 
						|
func (w *Watcher) GetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
						|
	var bestSelection *api.Brand
 | 
						|
	for _, t := range w.brands {
 | 
						|
		if t.WebCertificate.Get() == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if *t.Default {
 | 
						|
			bestSelection = &t
 | 
						|
		}
 | 
						|
		if strings.HasSuffix(ch.ServerName, t.Domain) {
 | 
						|
			bestSelection = &t
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if bestSelection == nil {
 | 
						|
		return w.fallback, nil
 | 
						|
	}
 | 
						|
	cert := w.cs.Get(bestSelection.GetWebCertificate())
 | 
						|
	if cert == nil {
 | 
						|
		return w.fallback, nil
 | 
						|
	}
 | 
						|
	return cert, nil
 | 
						|
}
 |