This commit is contained in:
Jens L
2020-09-03 00:04:12 +02:00
committed by GitHub
parent 14e47f3195
commit 268de20872
105 changed files with 6243 additions and 497 deletions

212
proxy/pkg/server/api.go Normal file
View File

@ -0,0 +1,212 @@
package server
import (
"crypto/sha512"
"encoding/hex"
"net/http"
"net/url"
"os"
"time"
"github.com/BeryJu/passbook/proxy/pkg/client"
"github.com/BeryJu/passbook/proxy/pkg/client/outposts"
"github.com/getsentry/sentry-go"
"github.com/go-openapi/runtime"
"github.com/recws-org/recws"
httptransport "github.com/go-openapi/runtime/client"
"github.com/go-openapi/strfmt"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
log "github.com/sirupsen/logrus"
)
const ConfigLogLevel = "log_level"
const ConfigErrorReportingEnabled = "error_reporting_enabled"
const ConfigErrorReportingEnvironment = "error_reporting_environment"
// APIController main controller which connects to the passbook api via http and ws
type APIController struct {
client *client.Passbook
auth runtime.ClientAuthInfoWriter
token string
server *Server
commonOpts *options.Options
lastBundleHash string
logger *log.Entry
wsConn recws.RecConn
}
func getCommonOptions() *options.Options {
commonOpts := options.NewOptions()
commonOpts.Cookie.Name = "passbook_proxy"
commonOpts.EmailDomains = []string{"*"}
commonOpts.ProviderType = "oidc"
commonOpts.ProxyPrefix = "/pbprox"
commonOpts.SkipProviderButton = true
commonOpts.Logging.SilencePing = true
commonOpts.SetXAuthRequest = true
commonOpts.SetAuthorization = true
return commonOpts
}
func doGlobalSetup(config map[string]interface{}) {
switch config[ConfigLogLevel].(string) {
case "debug":
log.SetLevel(log.DebugLevel)
case "info":
log.SetLevel(log.InfoLevel)
case "warning":
log.SetLevel(log.WarnLevel)
case "error":
log.SetLevel(log.ErrorLevel)
default:
log.SetLevel(log.DebugLevel)
}
var dsn string
if config[ConfigErrorReportingEnabled].(bool) {
dsn = "https://33cdbcb23f8b436dbe0ee06847410b67@sentry.beryju.org/3"
log.Debug("Error reporting enabled")
}
err := sentry.Init(sentry.ClientOptions{
Dsn: dsn,
Environment: config[ConfigErrorReportingEnvironment].(string),
})
if err != nil {
log.Fatalf("sentry.Init: %s", err)
}
defer sentry.Flush(2 * time.Second)
}
func getTLSTransport() http.RoundTripper {
_, set := os.LookupEnv("PASSBOOK_INSECURE")
tlsTransport, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
InsecureSkipVerify: set,
})
if err != nil {
panic(err)
}
return tlsTransport
}
// NewAPIController initialise new API Controller instance from URL and API token
func NewAPIController(pbURL url.URL, token string) *APIController {
transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
transport.Transport = getTLSTransport()
// create the transport
auth := httptransport.BasicAuth("", token)
// create the API client, with the transport
apiClient := client.New(transport, strfmt.Default)
// Because we don't know the outpost UUID, we simply do a list and pick the first
// The service account this token belongs to should only have access to a single outpost
outposts, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
if err != nil {
panic(err)
}
outpost := outposts.Payload.Results[0]
doGlobalSetup(outpost.Config.(map[string]interface{}))
ac := &APIController{
client: apiClient,
auth: auth,
token: token,
logger: log.WithField("component", "api-controller"),
commonOpts: getCommonOptions(),
server: NewServer(),
lastBundleHash: "",
}
ac.initWS(pbURL, outpost.Pk)
return ac
}
func (a *APIController) bundleProviders() ([]*providerBundle, error) {
providers, err := a.client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.auth)
if err != nil {
a.logger.WithError(err).Error("Failed to fetch providers")
return nil, err
}
// Check provider hash to see if anything is changed
hasher := sha512.New()
bin, _ := providers.Payload.MarshalBinary()
hash := hex.EncodeToString(hasher.Sum(bin))
if hash == a.lastBundleHash {
return nil, nil
}
a.lastBundleHash = hash
bundles := make([]*providerBundle, len(providers.Payload.Results))
for idx, provider := range providers.Payload.Results {
externalHost, err := url.Parse(*provider.ExternalHost)
if err != nil {
log.WithError(err).Warning("Failed to parse URL, skipping provider")
}
bundles[idx] = &providerBundle{
a: a,
Host: externalHost.Hostname(),
}
bundles[idx].Build(provider)
}
return bundles, nil
}
func (a *APIController) updateHTTPServer(bundles []*providerBundle) {
newMap := make(map[string]*providerBundle)
for _, bundle := range bundles {
newMap[bundle.Host] = bundle
}
a.logger.Debug("Swapped maps")
a.server.Handlers = newMap
}
// UpdateIfRequired Updates the HTTP Server config if required, automatically swaps the handlers
func (a *APIController) UpdateIfRequired() error {
bundles, err := a.bundleProviders()
if err != nil {
return err
}
if bundles == nil {
a.logger.Debug("Providers have not changed, not updating")
return nil
}
a.updateHTTPServer(bundles)
return nil
}
// Start Starts all handlers, non-blocking
func (a *APIController) Start() error {
err := a.UpdateIfRequired()
if err != nil {
return err
}
go func() {
a.logger.Debug("Starting HTTP Server...")
a.server.ServeHTTP()
}()
go func() {
a.logger.Debug("Starting HTTPs Server...")
a.server.ServeHTTPS()
}()
go func() {
a.logger.Debug("Starting WS Handler...")
a.startWSHandler()
}()
go func() {
a.logger.Debug("Starting WS Health notifier...")
a.startWSHealth()
}()
return nil
}

View File

@ -0,0 +1,123 @@
package server
import (
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"os"
"github.com/BeryJu/passbook/proxy/pkg/client/crypto"
"github.com/BeryJu/passbook/proxy/pkg/models"
"github.com/BeryJu/passbook/proxy/pkg/proxy"
"github.com/jinzhu/copier"
"github.com/justinas/alice"
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
log "github.com/sirupsen/logrus"
)
type providerBundle struct {
http.Handler
a *APIController
proxy *proxy.OAuthProxy
Host string
cert *tls.Certificate
}
func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *options.Options {
externalHost, err := url.Parse(*provider.ExternalHost)
if err != nil {
log.WithError(err).Warning("Failed to parse URL, skipping provider")
return nil
}
providerOpts := &options.Options{}
copier.Copy(&providerOpts, &pb.a.commonOpts)
providerOpts.ClientID = provider.ClientID
providerOpts.ClientSecret = provider.ClientSecret
providerOpts.Cookie.Secret = provider.CookieSecret
providerOpts.Cookie.Secure = externalHost.Scheme == "https"
providerOpts.SkipOIDCDiscovery = true
providerOpts.OIDCIssuerURL = *provider.OidcConfiguration.Issuer
providerOpts.LoginURL = *provider.OidcConfiguration.AuthorizationEndpoint
providerOpts.RedeemURL = *provider.OidcConfiguration.TokenEndpoint
providerOpts.OIDCJwksURL = *provider.OidcConfiguration.JwksURI
providerOpts.ProfileURL = *provider.OidcConfiguration.UserinfoEndpoint
providerOpts.UpstreamServers = []options.Upstream{
{
ID: "default",
URI: *provider.InternalHost,
Path: "/",
},
}
if provider.Certificate != nil {
pb.a.logger.WithField("provider", provider.ClientID).Debug("Enabling TLS")
cert, err := pb.a.client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
Context: context.Background(),
KpUUID: *provider.Certificate,
}, pb.a.auth)
if err != nil {
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
return providerOpts
}
x509cert, err := tls.X509KeyPair([]byte(*cert.Payload.CertificateData), []byte(cert.Payload.KeyData))
if err != nil {
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
return providerOpts
}
pb.cert = &x509cert
pb.a.logger.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
}
return providerOpts
}
func (pb *providerBundle) Build(provider *models.ProxyOutpostConfig) {
opts := pb.prepareOpts(provider)
chain := alice.New()
if opts.ForceHTTPS {
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
if err != nil {
log.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
}
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
}
healthCheckPaths := []string{opts.PingPath}
healthCheckUserAgents := []string{opts.PingUserAgent}
if opts.GCPHealthChecks {
healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
}
// To silence logging of health checks, register the health check handler before
// the logging handler
if opts.Logging.SilencePing {
chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
} else {
chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
}
err := validation.Validate(opts)
if err != nil {
log.Printf("%s", err)
os.Exit(1)
}
oauthproxy, err := proxy.NewOAuthProxy(opts)
if err != nil {
log.Errorf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
os.Exit(1)
}
pb.proxy = oauthproxy
pb.Handler = chain.Then(oauthproxy)
}

View File

@ -0,0 +1,85 @@
package server
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/go-openapi/strfmt"
"github.com/gorilla/websocket"
"github.com/recws-org/recws"
)
func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
pathTemplate := "%s://%s/ws/outpost/%s/"
scheme := strings.ReplaceAll(pbURL.Scheme, "http", "ws")
header := http.Header{
"Authorization": []string{ac.token},
}
_, set := os.LookupEnv("PASSBOOK_INSECURE")
ws := recws.RecConn{
// KeepAliveTimeout: 10 * time.Second,
NonVerbose: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: set,
},
}
ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header)
ac.logger.WithField("outpost", outpostUUID.String()).Debug("connecting to passbook")
ac.wsConn = ws
}
// Shutdown Gracefully stops all workers, disconnects from websocket
func (ac *APIController) Shutdown() {
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
ac.logger.Println("write close:", err)
return
}
return
}
func (ac *APIController) startWSHandler() {
for {
var wsMsg websocketMessage
err := ac.wsConn.ReadJSON(&wsMsg)
if err != nil {
ac.logger.Println("read:", err)
return
}
if wsMsg.Instruction != WebsocketInstructionAck {
ac.logger.Debugf("%+v\n", wsMsg)
}
if wsMsg.Instruction == WebsocketInstructionTriggerUpdate {
err := ac.UpdateIfRequired()
if err != nil {
ac.logger.WithError(err).Debug("Failed to update")
}
}
}
}
func (ac *APIController) startWSHealth() {
for ; true; <-time.Tick(time.Second * 10) {
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: make(map[string]interface{}),
}
err := ac.wsConn.WriteJSON(aliveMsg)
if err != nil {
ac.logger.Println("write:", err)
return
}
}
}

View File

@ -0,0 +1,17 @@
package server
type websocketInstruction int
const (
// WebsocketInstructionAck Code used to acknowledge a previous message
WebsocketInstructionAck websocketInstruction = 0
// WebsocketInstructionHello Code used to send a healthcheck keepalive
WebsocketInstructionHello websocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2
)
type websocketMessage struct {
Instruction websocketInstruction `json:"instruction"`
Args map[string]interface{} `json:"args"`
}

63
proxy/pkg/server/cert.go Normal file
View File

@ -0,0 +1,63 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"time"
log "github.com/sirupsen/logrus"
)
func generateSelfSignedCert() (tls.Certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatalf("Failed to generate private key: %v", err)
return tls.Certificate{}, err
}
keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
log.Fatalf("Failed to generate serial number: %v", err)
return tls.Certificate{}, err
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"passbook"},
CommonName: "passbook Proxy default certificate",
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
template.DNSNames = []string{"*"}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
log.Warning(err)
}
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
log.Warning(err)
}
privPemByes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
return tls.X509KeyPair(pemBytes, privPemByes)
}

View File

@ -0,0 +1,123 @@
package server
import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"time"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
log "github.com/sirupsen/logrus"
)
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
// code and body size
type responseLogger struct {
w http.ResponseWriter
status int
size int
upstream string
authInfo string
}
// Header returns the ResponseWriter's Header
func (l *responseLogger) Header() http.Header {
return l.w.Header()
}
// Support Websocket
func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if hj, ok := l.w.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, errors.New("http.Hijacker is not available on writer")
}
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
// Header
func (l *responseLogger) ExtractGAPMetadata() {
upstream := l.w.Header().Get("GAP-Upstream-Address")
if upstream != "" {
l.upstream = upstream
l.w.Header().Del("GAP-Upstream-Address")
}
authInfo := l.w.Header().Get("GAP-Auth")
if authInfo != "" {
l.authInfo = authInfo
l.w.Header().Del("GAP-Auth")
}
}
// Write writes the response using the ResponseWriter
func (l *responseLogger) Write(b []byte) (int, error) {
if l.status == 0 {
// The status will be StatusOK if WriteHeader has not been called yet
l.status = http.StatusOK
}
l.ExtractGAPMetadata()
size, err := l.w.Write(b)
l.size += size
return size, err
}
// WriteHeader writes the status code for the Response
func (l *responseLogger) WriteHeader(s int) {
l.ExtractGAPMetadata()
l.w.WriteHeader(s)
l.status = s
}
// Status returns the response status code
func (l *responseLogger) Status() int {
return l.status
}
// Size returns the response size
func (l *responseLogger) Size() int {
return l.size
}
// Flush sends any buffered data to the client
func (l *responseLogger) Flush() {
if flusher, ok := l.w.(http.Flusher); ok {
flusher.Flush()
}
}
// loggingHandler is the http.Handler implementation for LoggingHandler
type loggingHandler struct {
handler http.Handler
logger *log.Entry
}
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
func LoggingHandler(h http.Handler) http.Handler {
return loggingHandler{
handler: h,
logger: log.WithField("component", "http-server"),
}
}
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
t := time.Now()
url := *req.URL
responseLogger := &responseLogger{w: w}
h.handler.ServeHTTP(responseLogger, req)
duration := float64(time.Since(t)) / float64(time.Second)
h.logger.WithFields(log.Fields{
"Client": req.RemoteAddr,
"Host": req.Host,
"Protocol": req.Proto,
"RequestDuration": fmt.Sprintf("%0.3f", duration),
"RequestMethod": req.Method,
"ResponseSize": responseLogger.Size(),
"StatusCode": responseLogger.Status(),
"Timestamp": logger.FormatTimestamp(t),
"Upstream": responseLogger.upstream,
"UserAgent": req.UserAgent(),
"Username": responseLogger.authInfo,
}).Info(url.RequestURI())
// logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, , )
}

152
proxy/pkg/server/server.go Normal file
View File

@ -0,0 +1,152 @@
package server
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"time"
log "github.com/sirupsen/logrus"
sentryhttp "github.com/getsentry/sentry-go/http"
)
// Server represents an HTTP server
type Server struct {
Handlers map[string]*providerBundle
stop chan struct{} // channel for waiting shutdown
logger *log.Entry
defaultCert tls.Certificate
}
// NewServer initialise a new HTTP Server
func NewServer() *Server {
defaultCert, err := generateSelfSignedCert()
if err != nil {
log.Warning(err)
}
return &Server{
Handlers: make(map[string]*providerBundle),
logger: log.WithField("component", "http-server"),
defaultCert: defaultCert,
}
}
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
func (s *Server) ServeHTTP() {
// TODO: make this a setting
listenAddress := "localhost:4180"
listener, err := net.Listen("tcp", listenAddress)
if err != nil {
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
}
s.logger.Printf("listening on %s", listener.Addr())
s.serve(listener)
s.logger.Printf("closing %s", listener.Addr())
}
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
handler, ok := s.Handlers[info.ServerName]
if !ok {
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
return &s.defaultCert, nil
}
if handler.cert == nil {
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
return &s.defaultCert, nil
}
return handler.cert, nil
}
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
func (s *Server) ServeHTTPS() {
// TODO: make this a setting
listenAddress := "localhost:4443"
config := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
GetCertificate: s.getCertificates,
}
ln, err := net.Listen("tcp", listenAddress)
if err != nil {
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
}
s.logger.Printf("listening on %s", ln.Addr())
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
s.serve(tlsListener)
s.logger.Printf("closing %s", tlsListener.Addr())
}
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
handler, ok := s.Handlers[r.Host]
if !ok {
// If we only have one handler, host name switching doesn't matter
if len(s.Handlers) == 1 {
for k := range s.Handlers {
s.Handlers[k].ServeHTTP(w, r)
return
}
}
s.logger.WithField("host", r.Host).Debug("Host header does not match any we know of")
s.logger.Printf("%v+\n", s.Handlers)
w.WriteHeader(400)
return
}
s.logger.WithField("host", r.Host).Debug("passing request from host head")
handler.ServeHTTP(w, r)
}
func (s *Server) serve(listener net.Listener) {
sentryHandler := sentryhttp.New(sentryhttp.Options{})
srv := &http.Server{Handler: sentryHandler.HandleFunc(s.handler)}
// See https://golang.org/pkg/net/http/#Server.Shutdown
idleConnsClosed := make(chan struct{})
go func() {
<-s.stop // wait notification for stopping server
// We received an interrupt signal, shut down.
if err := srv.Shutdown(context.Background()); err != nil {
// Error from closing listeners, or context timeout:
s.logger.Printf("HTTP server Shutdown: %v", err)
}
close(idleConnsClosed)
}()
err := srv.Serve(listener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Errorf("ERROR: http.Serve() - %s", err)
}
<-idleConnsClosed
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
err = tc.SetKeepAlive(true)
if err != nil {
log.Printf("Error setting Keep-Alive: %v", err)
}
err = tc.SetKeepAlivePeriod(3 * time.Minute)
if err != nil {
log.Printf("Error setting Keep-Alive period: %v", err)
}
return tc, nil
}