outpost: separate ak-api and proxy further for future outposts
This commit is contained in:
		
							
								
								
									
										100
									
								
								outpost/pkg/ak/api.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								outpost/pkg/ak/api.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,100 @@
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client/outposts"
 | 
			
		||||
	"github.com/go-openapi/runtime"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/recws-org/recws"
 | 
			
		||||
 | 
			
		||||
	httptransport "github.com/go-openapi/runtime/client"
 | 
			
		||||
	"github.com/go-openapi/strfmt"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const ConfigLogLevel = "log_level"
 | 
			
		||||
const ConfigErrorReportingEnabled = "error_reporting_enabled"
 | 
			
		||||
const ConfigErrorReportingEnvironment = "error_reporting_environment"
 | 
			
		||||
 | 
			
		||||
// APIController main controller which connects to the authentik api via http and ws
 | 
			
		||||
type APIController struct {
 | 
			
		||||
	Client *client.Authentik
 | 
			
		||||
	Auth   runtime.ClientAuthInfoWriter
 | 
			
		||||
	token  string
 | 
			
		||||
 | 
			
		||||
	Server Outpost
 | 
			
		||||
 | 
			
		||||
	lastBundleHash string
 | 
			
		||||
	logger         *log.Entry
 | 
			
		||||
 | 
			
		||||
	reloadOffset time.Duration
 | 
			
		||||
 | 
			
		||||
	wsConn *recws.RecConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewAPIController initialise new API Controller instance from URL and API token
 | 
			
		||||
func NewAPIController(pbURL url.URL, token string) *APIController {
 | 
			
		||||
	transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
 | 
			
		||||
	transport.Transport = SetUserAgent(getTLSTransport(), fmt.Sprintf("authentik-proxy@%s", pkg.VERSION))
 | 
			
		||||
 | 
			
		||||
	// create the transport
 | 
			
		||||
	auth := httptransport.BasicAuth("", token)
 | 
			
		||||
 | 
			
		||||
	// create the API client, with the transport
 | 
			
		||||
	apiClient := client.New(transport, strfmt.Default)
 | 
			
		||||
 | 
			
		||||
	// Because we don't know the outpost UUID, we simply do a list and pick the first
 | 
			
		||||
	// The service account this token belongs to should only have access to a single outpost
 | 
			
		||||
	outposts, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	outpost := outposts.Payload.Results[0]
 | 
			
		||||
	doGlobalSetup(outpost.Config.(map[string]interface{}))
 | 
			
		||||
 | 
			
		||||
	ac := &APIController{
 | 
			
		||||
		Client: apiClient,
 | 
			
		||||
		Auth:   auth,
 | 
			
		||||
		token:  token,
 | 
			
		||||
 | 
			
		||||
		logger: log.WithField("component", "ak-api-controller"),
 | 
			
		||||
 | 
			
		||||
		reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
 | 
			
		||||
 | 
			
		||||
		lastBundleHash: "",
 | 
			
		||||
	}
 | 
			
		||||
	ac.logger.Debugf("HA Reload offset: %s", ac.reloadOffset)
 | 
			
		||||
	ac.initWS(pbURL, outpost.Pk)
 | 
			
		||||
	return ac
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *APIController) GetLastBundleHash() string {
 | 
			
		||||
	return a.lastBundleHash
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Start Starts all handlers, non-blocking
 | 
			
		||||
func (a *APIController) Start() error {
 | 
			
		||||
	err := a.Server.Refresh()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "failed to run initial refresh")
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting WS Handler...")
 | 
			
		||||
		a.startWSHandler()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting WS Health notifier...")
 | 
			
		||||
		a.startWSHealth()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.Server.Start()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import "net/http"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										26
									
								
								outpost/pkg/ak/api_update.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								outpost/pkg/ak/api_update.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/sha512"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client/outposts"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/models"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (a *APIController) Update() ([]*models.ProxyOutpostConfig, error) {
 | 
			
		||||
	providers, err := a.Client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.Auth)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.logger.WithError(err).Error("Failed to fetch providers")
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// Check provider hash to see if anything is changed
 | 
			
		||||
	hasher := sha512.New()
 | 
			
		||||
	bin, _ := providers.Payload.MarshalBinary()
 | 
			
		||||
	hash := hex.EncodeToString(hasher.Sum(bin))
 | 
			
		||||
	if hash == a.lastBundleHash {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	a.lastBundleHash = hash
 | 
			
		||||
	return providers.Payload.Results, nil
 | 
			
		||||
}
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
@ -40,7 +40,7 @@ func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
 | 
			
		||||
	}
 | 
			
		||||
	ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header)
 | 
			
		||||
 | 
			
		||||
	ac.logger.WithField("component", "ws").WithField("outpost", outpostUUID.String()).Debug("connecting to authentik")
 | 
			
		||||
	ac.logger.WithField("component", "ak-ws").WithField("outpost", outpostUUID.String()).Debug("connecting to authentik")
 | 
			
		||||
 | 
			
		||||
	ac.wsConn = ws
 | 
			
		||||
	// Send hello message with our version
 | 
			
		||||
@ -52,7 +52,7 @@ func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
 | 
			
		||||
	}
 | 
			
		||||
	err := ws.WriteJSON(msg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		ac.logger.WithField("component", "ws").WithError(err).Warning("Failed to hello to authentik")
 | 
			
		||||
		ac.logger.WithField("component", "ak-ws").WithError(err).Warning("Failed to hello to authentik")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,7 @@ func (ac *APIController) startWSHandler() {
 | 
			
		||||
		}
 | 
			
		||||
		if wsMsg.Instruction == WebsocketInstructionTriggerUpdate {
 | 
			
		||||
			time.Sleep(ac.reloadOffset)
 | 
			
		||||
			err := ac.UpdateIfRequired()
 | 
			
		||||
			err := ac.Server.Refresh()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				ac.logger.WithField("loop", "ws-handler").WithError(err).Debug("Failed to update")
 | 
			
		||||
			}
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
type websocketInstruction int
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
@ -13,8 +13,8 @@ import (
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func generateSelfSignedCert() (tls.Certificate, error) {
 | 
			
		||||
 | 
			
		||||
// GenerateSelfSignedCert Generate a self-signed TLS Certificate, to be used as fallback
 | 
			
		||||
func GenerateSelfSignedCert() (tls.Certificate, error) {
 | 
			
		||||
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("Failed to generate private key: %v", err)
 | 
			
		||||
							
								
								
									
										60
									
								
								outpost/pkg/ak/global.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								outpost/pkg/ak/global.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,60 @@
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg"
 | 
			
		||||
	"github.com/getsentry/sentry-go"
 | 
			
		||||
	httptransport "github.com/go-openapi/runtime/client"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func doGlobalSetup(config map[string]interface{}) {
 | 
			
		||||
	log.SetFormatter(&log.JSONFormatter{})
 | 
			
		||||
	switch config[ConfigLogLevel].(string) {
 | 
			
		||||
	case "debug":
 | 
			
		||||
		log.SetLevel(log.DebugLevel)
 | 
			
		||||
	case "info":
 | 
			
		||||
		log.SetLevel(log.InfoLevel)
 | 
			
		||||
	case "warning":
 | 
			
		||||
		log.SetLevel(log.WarnLevel)
 | 
			
		||||
	case "error":
 | 
			
		||||
		log.SetLevel(log.ErrorLevel)
 | 
			
		||||
	default:
 | 
			
		||||
		log.SetLevel(log.DebugLevel)
 | 
			
		||||
	}
 | 
			
		||||
	log.WithField("version", pkg.VERSION).Info("Starting authentik proxy")
 | 
			
		||||
 | 
			
		||||
	var dsn string
 | 
			
		||||
	if config[ConfigErrorReportingEnabled].(bool) {
 | 
			
		||||
		dsn = "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8"
 | 
			
		||||
		log.Debug("Error reporting enabled")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := sentry.Init(sentry.ClientOptions{
 | 
			
		||||
		Dsn:         dsn,
 | 
			
		||||
		Environment: config[ConfigErrorReportingEnvironment].(string),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("sentry.Init: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer sentry.Flush(2 * time.Second)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTLSTransport() http.RoundTripper {
 | 
			
		||||
	value, set := os.LookupEnv("AUTHENTIK_INSECURE")
 | 
			
		||||
	if !set {
 | 
			
		||||
		value = "false"
 | 
			
		||||
	}
 | 
			
		||||
	tlsTransport, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
 | 
			
		||||
		InsecureSkipVerify: strings.ToLower(value) == "true",
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return tlsTransport
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										6
									
								
								outpost/pkg/ak/outpost.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								outpost/pkg/ak/outpost.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,6 @@
 | 
			
		||||
package ak
 | 
			
		||||
 | 
			
		||||
type Outpost interface {
 | 
			
		||||
	Start() error
 | 
			
		||||
	Refresh() error
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										48
									
								
								outpost/pkg/proxy/api.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								outpost/pkg/proxy/api.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,48 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/models"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Server) Refresh() error {
 | 
			
		||||
	providers, err := s.ak.Update()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if providers == nil {
 | 
			
		||||
		s.logger.Debug("Providers have not changed, not updating")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	bundles := s.bundleProviders(providers)
 | 
			
		||||
	s.updateHTTPServer(bundles)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) bundleProviders(providers []*models.ProxyOutpostConfig) []*providerBundle {
 | 
			
		||||
	bundles := make([]*providerBundle, len(providers))
 | 
			
		||||
	for idx, provider := range providers {
 | 
			
		||||
		externalHost, err := url.Parse(*provider.ExternalHost)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Warning("Failed to parse URL, skipping provider")
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx] = &providerBundle{
 | 
			
		||||
			s:    s,
 | 
			
		||||
			Host: externalHost.Host,
 | 
			
		||||
			log:  log.WithField("component", "proxy-bundle").WithField("provider", provider.Name),
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx].Build(provider)
 | 
			
		||||
	}
 | 
			
		||||
	return bundles
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) updateHTTPServer(bundles []*providerBundle) {
 | 
			
		||||
	newMap := make(map[string]*providerBundle)
 | 
			
		||||
	for _, bundle := range bundles {
 | 
			
		||||
		newMap[bundle.Host] = bundle
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Debug("Swapped maps")
 | 
			
		||||
	s.Handlers = newMap
 | 
			
		||||
}
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
@ -11,7 +11,6 @@ import (
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client/crypto"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/models"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/proxy"
 | 
			
		||||
	"github.com/jinzhu/copier"
 | 
			
		||||
	"github.com/justinas/alice"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
@ -23,11 +22,13 @@ import (
 | 
			
		||||
type providerBundle struct {
 | 
			
		||||
	http.Handler
 | 
			
		||||
 | 
			
		||||
	a     *APIController
 | 
			
		||||
	proxy *proxy.OAuthProxy
 | 
			
		||||
	s     *Server
 | 
			
		||||
	proxy *OAuthProxy
 | 
			
		||||
	Host  string
 | 
			
		||||
 | 
			
		||||
	cert *tls.Certificate
 | 
			
		||||
 | 
			
		||||
	log *log.Entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *options.Options {
 | 
			
		||||
@ -37,7 +38,7 @@ func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *opti
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	providerOpts := &options.Options{}
 | 
			
		||||
	copier.Copy(&providerOpts, &pb.a.commonOpts)
 | 
			
		||||
	copier.Copy(&providerOpts, getCommonOptions())
 | 
			
		||||
	providerOpts.ClientID = provider.ClientID
 | 
			
		||||
	providerOpts.ClientSecret = provider.ClientSecret
 | 
			
		||||
 | 
			
		||||
@ -66,22 +67,22 @@ func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *opti
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if provider.Certificate != nil {
 | 
			
		||||
		pb.a.logger.WithField("provider", provider.ClientID).Debug("Enabling TLS")
 | 
			
		||||
		cert, err := pb.a.client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
 | 
			
		||||
		pb.log.WithField("provider", provider.ClientID).Debug("Enabling TLS")
 | 
			
		||||
		cert, err := pb.s.ak.Client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
 | 
			
		||||
			Context: context.Background(),
 | 
			
		||||
			KpUUID:  *provider.Certificate,
 | 
			
		||||
		}, pb.a.auth)
 | 
			
		||||
		}, pb.s.ak.Auth)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
 | 
			
		||||
			pb.log.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
 | 
			
		||||
			return providerOpts
 | 
			
		||||
		}
 | 
			
		||||
		x509cert, err := tls.X509KeyPair([]byte(*cert.Payload.CertificateData), []byte(cert.Payload.KeyData))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
 | 
			
		||||
			pb.log.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
 | 
			
		||||
			return providerOpts
 | 
			
		||||
		}
 | 
			
		||||
		pb.cert = &x509cert
 | 
			
		||||
		pb.a.logger.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
 | 
			
		||||
		pb.log.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
 | 
			
		||||
	}
 | 
			
		||||
	return providerOpts
 | 
			
		||||
}
 | 
			
		||||
@ -119,7 +120,7 @@ func (pb *providerBundle) Build(provider *models.ProxyOutpostConfig) {
 | 
			
		||||
		log.Printf("%s", err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
	oauthproxy, err := proxy.NewOAuthProxy(opts)
 | 
			
		||||
	oauthproxy, err := NewOAuthProxy(opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Errorf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
							
								
								
									
										20
									
								
								outpost/pkg/proxy/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								outpost/pkg/proxy/common.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,20 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getCommonOptions() *options.Options {
 | 
			
		||||
	commonOpts := options.NewOptions()
 | 
			
		||||
	commonOpts.Cookie.Name = "authentik_proxy"
 | 
			
		||||
	commonOpts.Cookie.Expire = 24 * time.Hour
 | 
			
		||||
	commonOpts.EmailDomains = []string{"*"}
 | 
			
		||||
	commonOpts.ProviderType = "oidc"
 | 
			
		||||
	commonOpts.ProxyPrefix = "/akprox"
 | 
			
		||||
	commonOpts.Logging.SilencePing = true
 | 
			
		||||
	commonOpts.SetAuthorization = false
 | 
			
		||||
	commonOpts.Scope = "openid email profile ak_proxy"
 | 
			
		||||
	return commonOpts
 | 
			
		||||
}
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
@ -95,7 +95,7 @@ type loggingHandler struct {
 | 
			
		||||
func LoggingHandler(h http.Handler) http.Handler {
 | 
			
		||||
	return loggingHandler{
 | 
			
		||||
		handler: h,
 | 
			
		||||
		logger:  log.WithField("component", "http-server"),
 | 
			
		||||
		logger:  log.WithField("component", "proxy-http-server"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
package server
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/ak"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -15,70 +16,26 @@ import (
 | 
			
		||||
type Server struct {
 | 
			
		||||
	Handlers map[string]*providerBundle
 | 
			
		||||
 | 
			
		||||
	stop   chan struct{} // channel for waiting shutdown
 | 
			
		||||
	logger *log.Entry
 | 
			
		||||
 | 
			
		||||
	stop        chan struct{} // channel for waiting shutdown
 | 
			
		||||
	logger      *log.Entry
 | 
			
		||||
	ak          *ak.APIController
 | 
			
		||||
	defaultCert tls.Certificate
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewServer initialise a new HTTP Server
 | 
			
		||||
func NewServer() *Server {
 | 
			
		||||
	defaultCert, err := generateSelfSignedCert()
 | 
			
		||||
func NewServer(ac *ak.APIController) *Server {
 | 
			
		||||
	defaultCert, err := ak.GenerateSelfSignedCert()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Warning(err)
 | 
			
		||||
	}
 | 
			
		||||
	return &Server{
 | 
			
		||||
		Handlers:    make(map[string]*providerBundle),
 | 
			
		||||
		logger:      log.WithField("component", "http-server"),
 | 
			
		||||
		logger:      log.WithField("component", "proxy-http-server"),
 | 
			
		||||
		defaultCert: defaultCert,
 | 
			
		||||
		ak:          ac,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | 
			
		||||
func (s *Server) ServeHTTP() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4180"
 | 
			
		||||
	listener, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", listener.Addr())
 | 
			
		||||
	s.serve(listener)
 | 
			
		||||
	s.logger.Printf("closing %s", listener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
	handler, ok := s.Handlers[info.ServerName]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	if handler.cert == nil {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	return handler.cert, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 | 
			
		||||
func (s *Server) ServeHTTPS() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4443"
 | 
			
		||||
	config := &tls.Config{
 | 
			
		||||
		MinVersion:     tls.VersionTLS12,
 | 
			
		||||
		MaxVersion:     tls.VersionTLS12,
 | 
			
		||||
		GetCertificate: s.getCertificates,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ln, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", ln.Addr())
 | 
			
		||||
 | 
			
		||||
	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
 | 
			
		||||
	s.serve(tlsListener)
 | 
			
		||||
	s.logger.Printf("closing %s", tlsListener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	if r.URL.Path == "/akprox/ping" {
 | 
			
		||||
		w.WriteHeader(204)
 | 
			
		||||
							
								
								
									
										68
									
								
								outpost/pkg/proxy/server_https.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								outpost/pkg/proxy/server_https.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | 
			
		||||
func (s *Server) ServeHTTP() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4180"
 | 
			
		||||
	listener, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", listener.Addr())
 | 
			
		||||
	s.serve(listener)
 | 
			
		||||
	s.logger.Printf("closing %s", listener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
	handler, ok := s.Handlers[info.ServerName]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	if handler.cert == nil {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	return handler.cert, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 | 
			
		||||
func (s *Server) ServeHTTPS() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4443"
 | 
			
		||||
	config := &tls.Config{
 | 
			
		||||
		MinVersion:     tls.VersionTLS12,
 | 
			
		||||
		MaxVersion:     tls.VersionTLS12,
 | 
			
		||||
		GetCertificate: s.getCertificates,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ln, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", ln.Addr())
 | 
			
		||||
 | 
			
		||||
	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
 | 
			
		||||
	s.serve(tlsListener)
 | 
			
		||||
	s.logger.Printf("closing %s", tlsListener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) Start() error {
 | 
			
		||||
	wg := sync.WaitGroup{}
 | 
			
		||||
	wg.Add(2)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		s.logger.Debug("Starting HTTP Server...")
 | 
			
		||||
		s.ServeHTTP()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		s.logger.Debug("Starting HTTPs Server...")
 | 
			
		||||
		s.ServeHTTPS()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@ -1,225 +0,0 @@
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/sha512"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client/outposts"
 | 
			
		||||
	"github.com/getsentry/sentry-go"
 | 
			
		||||
	"github.com/go-openapi/runtime"
 | 
			
		||||
	"github.com/recws-org/recws"
 | 
			
		||||
 | 
			
		||||
	httptransport "github.com/go-openapi/runtime/client"
 | 
			
		||||
	"github.com/go-openapi/strfmt"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const ConfigLogLevel = "log_level"
 | 
			
		||||
const ConfigErrorReportingEnabled = "error_reporting_enabled"
 | 
			
		||||
const ConfigErrorReportingEnvironment = "error_reporting_environment"
 | 
			
		||||
 | 
			
		||||
// APIController main controller which connects to the authentik api via http and ws
 | 
			
		||||
type APIController struct {
 | 
			
		||||
	client *client.Authentik
 | 
			
		||||
	auth   runtime.ClientAuthInfoWriter
 | 
			
		||||
	token  string
 | 
			
		||||
 | 
			
		||||
	server *Server
 | 
			
		||||
 | 
			
		||||
	commonOpts *options.Options
 | 
			
		||||
 | 
			
		||||
	lastBundleHash string
 | 
			
		||||
	logger         *log.Entry
 | 
			
		||||
 | 
			
		||||
	reloadOffset time.Duration
 | 
			
		||||
 | 
			
		||||
	wsConn *recws.RecConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getCommonOptions() *options.Options {
 | 
			
		||||
	commonOpts := options.NewOptions()
 | 
			
		||||
	commonOpts.Cookie.Name = "authentik_proxy"
 | 
			
		||||
	commonOpts.Cookie.Expire = 24 * time.Hour
 | 
			
		||||
	commonOpts.EmailDomains = []string{"*"}
 | 
			
		||||
	commonOpts.ProviderType = "oidc"
 | 
			
		||||
	commonOpts.ProxyPrefix = "/akprox"
 | 
			
		||||
	commonOpts.Logging.SilencePing = true
 | 
			
		||||
	commonOpts.SetAuthorization = false
 | 
			
		||||
	commonOpts.Scope = "openid email profile ak_proxy"
 | 
			
		||||
	return commonOpts
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doGlobalSetup(config map[string]interface{}) {
 | 
			
		||||
	log.SetFormatter(&log.JSONFormatter{})
 | 
			
		||||
	switch config[ConfigLogLevel].(string) {
 | 
			
		||||
	case "debug":
 | 
			
		||||
		log.SetLevel(log.DebugLevel)
 | 
			
		||||
	case "info":
 | 
			
		||||
		log.SetLevel(log.InfoLevel)
 | 
			
		||||
	case "warning":
 | 
			
		||||
		log.SetLevel(log.WarnLevel)
 | 
			
		||||
	case "error":
 | 
			
		||||
		log.SetLevel(log.ErrorLevel)
 | 
			
		||||
	default:
 | 
			
		||||
		log.SetLevel(log.DebugLevel)
 | 
			
		||||
	}
 | 
			
		||||
	log.WithField("version", pkg.VERSION).Info("Starting authentik proxy")
 | 
			
		||||
 | 
			
		||||
	var dsn string
 | 
			
		||||
	if config[ConfigErrorReportingEnabled].(bool) {
 | 
			
		||||
		dsn = "https://a579bb09306d4f8b8d8847c052d3a1d3@sentry.beryju.org/8"
 | 
			
		||||
		log.Debug("Error reporting enabled")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := sentry.Init(sentry.ClientOptions{
 | 
			
		||||
		Dsn:         dsn,
 | 
			
		||||
		Environment: config[ConfigErrorReportingEnvironment].(string),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("sentry.Init: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer sentry.Flush(2 * time.Second)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTLSTransport() http.RoundTripper {
 | 
			
		||||
	value, set := os.LookupEnv("AUTHENTIK_INSECURE")
 | 
			
		||||
	if !set {
 | 
			
		||||
		value = "false"
 | 
			
		||||
	}
 | 
			
		||||
	tlsTransport, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
 | 
			
		||||
		InsecureSkipVerify: strings.ToLower(value) == "true",
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return tlsTransport
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewAPIController initialise new API Controller instance from URL and API token
 | 
			
		||||
func NewAPIController(pbURL url.URL, token string) *APIController {
 | 
			
		||||
	transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
 | 
			
		||||
	transport.Transport = SetUserAgent(getTLSTransport(), fmt.Sprintf("authentik-proxy@%s", pkg.VERSION))
 | 
			
		||||
 | 
			
		||||
	// create the transport
 | 
			
		||||
	auth := httptransport.BasicAuth("", token)
 | 
			
		||||
 | 
			
		||||
	// create the API client, with the transport
 | 
			
		||||
	apiClient := client.New(transport, strfmt.Default)
 | 
			
		||||
 | 
			
		||||
	// Because we don't know the outpost UUID, we simply do a list and pick the first
 | 
			
		||||
	// The service account this token belongs to should only have access to a single outpost
 | 
			
		||||
	outposts, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	outpost := outposts.Payload.Results[0]
 | 
			
		||||
	doGlobalSetup(outpost.Config.(map[string]interface{}))
 | 
			
		||||
 | 
			
		||||
	ac := &APIController{
 | 
			
		||||
		client: apiClient,
 | 
			
		||||
		auth:   auth,
 | 
			
		||||
		token:  token,
 | 
			
		||||
 | 
			
		||||
		logger:     log.WithField("component", "api-controller"),
 | 
			
		||||
		commonOpts: getCommonOptions(),
 | 
			
		||||
		server:     NewServer(),
 | 
			
		||||
 | 
			
		||||
		reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
 | 
			
		||||
 | 
			
		||||
		lastBundleHash: "",
 | 
			
		||||
	}
 | 
			
		||||
	ac.logger.Debugf("HA Reload offset: %s", ac.reloadOffset)
 | 
			
		||||
	ac.initWS(pbURL, outpost.Pk)
 | 
			
		||||
	return ac
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *APIController) bundleProviders() ([]*providerBundle, error) {
 | 
			
		||||
	providers, err := a.client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.auth)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		a.logger.WithError(err).Error("Failed to fetch providers")
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// Check provider hash to see if anything is changed
 | 
			
		||||
	hasher := sha512.New()
 | 
			
		||||
	bin, _ := providers.Payload.MarshalBinary()
 | 
			
		||||
	hash := hex.EncodeToString(hasher.Sum(bin))
 | 
			
		||||
	if hash == a.lastBundleHash {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	a.lastBundleHash = hash
 | 
			
		||||
 | 
			
		||||
	bundles := make([]*providerBundle, len(providers.Payload.Results))
 | 
			
		||||
 | 
			
		||||
	for idx, provider := range providers.Payload.Results {
 | 
			
		||||
		externalHost, err := url.Parse(*provider.ExternalHost)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Warning("Failed to parse URL, skipping provider")
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx] = &providerBundle{
 | 
			
		||||
			a:    a,
 | 
			
		||||
			Host: externalHost.Host,
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx].Build(provider)
 | 
			
		||||
	}
 | 
			
		||||
	return bundles, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *APIController) updateHTTPServer(bundles []*providerBundle) {
 | 
			
		||||
	newMap := make(map[string]*providerBundle)
 | 
			
		||||
	for _, bundle := range bundles {
 | 
			
		||||
		newMap[bundle.Host] = bundle
 | 
			
		||||
	}
 | 
			
		||||
	a.logger.Debug("Swapped maps")
 | 
			
		||||
	a.server.Handlers = newMap
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateIfRequired Updates the HTTP Server config if required, automatically swaps the handlers
 | 
			
		||||
func (a *APIController) UpdateIfRequired() error {
 | 
			
		||||
	bundles, err := a.bundleProviders()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if bundles == nil {
 | 
			
		||||
		a.logger.Debug("Providers have not changed, not updating")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	a.updateHTTPServer(bundles)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Start Starts all handlers, non-blocking
 | 
			
		||||
func (a *APIController) Start() error {
 | 
			
		||||
	err := a.UpdateIfRequired()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting HTTP Server...")
 | 
			
		||||
		a.server.ServeHTTP()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting HTTPs Server...")
 | 
			
		||||
		a.server.ServeHTTPS()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting WS Handler...")
 | 
			
		||||
		a.startWSHandler()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		a.logger.Debug("Starting WS Health notifier...")
 | 
			
		||||
		a.startWSHealth()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user