225 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			225 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package server
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/sha512"
 | 
						|
	"encoding/hex"
 | 
						|
	"math/rand"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"os"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/BeryJu/passbook/proxy/pkg"
 | 
						|
	"github.com/BeryJu/passbook/proxy/pkg/client"
 | 
						|
	"github.com/BeryJu/passbook/proxy/pkg/client/outposts"
 | 
						|
	"github.com/getsentry/sentry-go"
 | 
						|
	"github.com/go-openapi/runtime"
 | 
						|
	"github.com/recws-org/recws"
 | 
						|
 | 
						|
	httptransport "github.com/go-openapi/runtime/client"
 | 
						|
	"github.com/go-openapi/strfmt"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
						|
	log "github.com/sirupsen/logrus"
 | 
						|
)
 | 
						|
 | 
						|
// Keys looked up in the outpost's config map by doGlobalSetup.
const (
	// ConfigLogLevel selects the log level ("debug", "info", "warning" or "error").
	ConfigLogLevel = "log_level"
	// ConfigErrorReportingEnabled toggles Sentry error reporting; expected to hold a bool.
	ConfigErrorReportingEnabled = "error_reporting_enabled"
	// ConfigErrorReportingEnvironment names the Sentry environment; expected to hold a string.
	ConfigErrorReportingEnvironment = "error_reporting_environment"
)
 | 
						|
 | 
						|
// APIController main controller which connects to the passbook api via http and ws
 | 
						|
type APIController struct {
 | 
						|
	client *client.Passbook
 | 
						|
	auth   runtime.ClientAuthInfoWriter
 | 
						|
	token  string
 | 
						|
 | 
						|
	server *Server
 | 
						|
 | 
						|
	commonOpts *options.Options
 | 
						|
 | 
						|
	lastBundleHash string
 | 
						|
	logger         *log.Entry
 | 
						|
 | 
						|
	reloadOffset time.Duration
 | 
						|
 | 
						|
	wsConn *recws.RecConn
 | 
						|
}
 | 
						|
 | 
						|
func getCommonOptions() *options.Options {
 | 
						|
	commonOpts := options.NewOptions()
 | 
						|
	commonOpts.Cookie.Name = "passbook_proxy"
 | 
						|
	commonOpts.EmailDomains = []string{"*"}
 | 
						|
	commonOpts.ProviderType = "oidc"
 | 
						|
	commonOpts.ProxyPrefix = "/pbprox"
 | 
						|
	commonOpts.SkipProviderButton = true
 | 
						|
	commonOpts.Logging.SilencePing = true
 | 
						|
	commonOpts.SetAuthorization = false
 | 
						|
	commonOpts.Scope = "openid email profile pb_proxy"
 | 
						|
	return commonOpts
 | 
						|
}
 | 
						|
 | 
						|
func doGlobalSetup(config map[string]interface{}) {
 | 
						|
	switch config[ConfigLogLevel].(string) {
 | 
						|
	case "debug":
 | 
						|
		log.SetLevel(log.DebugLevel)
 | 
						|
	case "info":
 | 
						|
		log.SetLevel(log.InfoLevel)
 | 
						|
	case "warning":
 | 
						|
		log.SetLevel(log.WarnLevel)
 | 
						|
	case "error":
 | 
						|
		log.SetLevel(log.ErrorLevel)
 | 
						|
	default:
 | 
						|
		log.SetLevel(log.DebugLevel)
 | 
						|
	}
 | 
						|
	log.WithField("version", pkg.VERSION).Info("Starting passbook proxy")
 | 
						|
 | 
						|
	var dsn string
 | 
						|
	if config[ConfigErrorReportingEnabled].(bool) {
 | 
						|
		dsn = "https://33cdbcb23f8b436dbe0ee06847410b67@sentry.beryju.org/3"
 | 
						|
		log.Debug("Error reporting enabled")
 | 
						|
	}
 | 
						|
 | 
						|
	err := sentry.Init(sentry.ClientOptions{
 | 
						|
		Dsn:         dsn,
 | 
						|
		Environment: config[ConfigErrorReportingEnvironment].(string),
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		log.Fatalf("sentry.Init: %s", err)
 | 
						|
	}
 | 
						|
 | 
						|
	defer sentry.Flush(2 * time.Second)
 | 
						|
}
 | 
						|
 | 
						|
// getTLSTransport builds the round tripper used for passbook API calls.
// TLS certificate verification is skipped only when the PASSBOOK_INSECURE
// environment variable equals "true" (case-insensitive); when it is unset,
// verification stays enabled. Panics if the transport cannot be created.
func getTLSTransport() http.RoundTripper {
	skipVerify := strings.ToLower(os.Getenv("PASSBOOK_INSECURE")) == "true"
	rt, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
		InsecureSkipVerify: skipVerify,
	})
	if err != nil {
		panic(err)
	}
	return rt
}
 | 
						|
 | 
						|
// NewAPIController initialises a new API Controller instance from the passbook
// URL and an API token.
//
// It looks up the outpost the token belongs to, applies that outpost's config
// via doGlobalSetup and sets up the websocket connection via initWS. The
// signature has no error return, so unrecoverable setup failures panic.
func NewAPIController(pbURL url.URL, token string) *APIController {
	// HTTP transport for the generated API client; TLS settings come from
	// getTLSTransport.
	transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
	transport.Transport = getTLSTransport()

	// The token is sent as the basic-auth password with an empty username.
	auth := httptransport.BasicAuth("", token)

	// Create the API client on top of the transport.
	apiClient := client.New(transport, strfmt.Default)

	// Because we don't know the outpost UUID, we simply do a list and pick the first.
	// The service account this token belongs to should only have access to a single outpost.
	// (The result is not named "outposts" so the imported package stays accessible.)
	outpostList, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
	if err != nil {
		panic(err)
	}
	if len(outpostList.Payload.Results) == 0 {
		panic("passbook returned no outposts for this token; check the service account's permissions")
	}
	outpost := outpostList.Payload.Results[0]

	config, ok := outpost.Config.(map[string]interface{})
	if !ok {
		panic("outpost config is not a key/value object")
	}
	doGlobalSetup(config)

	// Use a time-seeded source: before Go 1.20 the global math/rand source is
	// deterministically seeded, which would give every replica the same offset.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))

	ac := &APIController{
		client: apiClient,
		auth:   auth,
		token:  token,

		logger:     log.WithField("component", "api-controller"),
		commonOpts: getCommonOptions(),
		server:     NewServer(),

		reloadOffset: time.Duration(rng.Intn(10)) * time.Second,

		lastBundleHash: "",
	}
	ac.logger.Debugf("HA Reload offset: %s", ac.reloadOffset)
	ac.initWS(pbURL, outpost.Pk)
	return ac
}
 | 
						|
 | 
						|
// bundleProviders fetches the proxy providers for this outpost and builds one
// providerBundle per provider.
//
// It returns (nil, nil) when the serialised provider list hashes the same as
// on the previous call, so the caller can skip reloading. Otherwise it returns
// a non-nil (possibly empty) slice. Providers with a missing or unparsable
// external host are logged and skipped, so the slice never contains nil entries.
func (a *APIController) bundleProviders() ([]*providerBundle, error) {
	providers, err := a.client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.auth)
	if err != nil {
		a.logger.WithError(err).Error("Failed to fetch providers")
		return nil, err
	}

	// Check provider hash to see if anything is changed.
	bin, err := providers.Payload.MarshalBinary()
	if err != nil {
		a.logger.WithError(err).Error("Failed to serialise providers")
		return nil, err
	}
	// hash.Hash.Sum(b) appends the digest to b rather than hashing b, so hash
	// the serialised bytes directly.
	digest := sha512.Sum512(bin)
	hash := hex.EncodeToString(digest[:])
	if hash == a.lastBundleHash {
		return nil, nil
	}
	a.lastBundleHash = hash

	// Built with append so skipped providers leave no nil holes.
	bundles := make([]*providerBundle, 0, len(providers.Payload.Results))
	for _, provider := range providers.Payload.Results {
		if provider == nil || provider.ExternalHost == nil {
			a.logger.Warning("Provider has no external host, skipping provider")
			continue
		}
		externalHost, err := url.Parse(*provider.ExternalHost)
		if err != nil {
			a.logger.WithError(err).Warning("Failed to parse URL, skipping provider")
			continue
		}
		bundle := &providerBundle{
			a:    a,
			Host: externalHost.Hostname(),
		}
		bundle.Build(provider)
		bundles = append(bundles, bundle)
	}
	return bundles, nil
}
 | 
						|
 | 
						|
func (a *APIController) updateHTTPServer(bundles []*providerBundle) {
 | 
						|
	newMap := make(map[string]*providerBundle)
 | 
						|
	for _, bundle := range bundles {
 | 
						|
		newMap[bundle.Host] = bundle
 | 
						|
	}
 | 
						|
	a.logger.Debug("Swapped maps")
 | 
						|
	a.server.Handlers = newMap
 | 
						|
}
 | 
						|
 | 
						|
// UpdateIfRequired Updates the HTTP Server config if required, automatically swaps the handlers
 | 
						|
func (a *APIController) UpdateIfRequired() error {
 | 
						|
	bundles, err := a.bundleProviders()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if bundles == nil {
 | 
						|
		a.logger.Debug("Providers have not changed, not updating")
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	a.updateHTTPServer(bundles)
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Start Starts all handlers, non-blocking
 | 
						|
func (a *APIController) Start() error {
 | 
						|
	err := a.UpdateIfRequired()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	go func() {
 | 
						|
		a.logger.Debug("Starting HTTP Server...")
 | 
						|
		a.server.ServeHTTP()
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		a.logger.Debug("Starting HTTPs Server...")
 | 
						|
		a.server.ServeHTTPS()
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		a.logger.Debug("Starting WS Handler...")
 | 
						|
		a.startWSHandler()
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		a.logger.Debug("Starting WS Health notifier...")
 | 
						|
		a.startWSHealth()
 | 
						|
	}()
 | 
						|
	return nil
 | 
						|
}
 |