Proxy v2 (#189)
This commit is contained in:
212
proxy/pkg/server/api.go
Normal file
212
proxy/pkg/server/api.go
Normal file
@ -0,0 +1,212 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client/outposts"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/go-openapi/runtime"
|
||||
"github.com/recws-org/recws"
|
||||
|
||||
httptransport "github.com/go-openapi/runtime/client"
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const ConfigLogLevel = "log_level"
|
||||
const ConfigErrorReportingEnabled = "error_reporting_enabled"
|
||||
const ConfigErrorReportingEnvironment = "error_reporting_environment"
|
||||
|
||||
// APIController main controller which connects to the passbook api via http and ws
|
||||
type APIController struct {
|
||||
client *client.Passbook
|
||||
auth runtime.ClientAuthInfoWriter
|
||||
token string
|
||||
|
||||
server *Server
|
||||
|
||||
commonOpts *options.Options
|
||||
|
||||
lastBundleHash string
|
||||
logger *log.Entry
|
||||
|
||||
wsConn recws.RecConn
|
||||
}
|
||||
|
||||
func getCommonOptions() *options.Options {
|
||||
commonOpts := options.NewOptions()
|
||||
commonOpts.Cookie.Name = "passbook_proxy"
|
||||
commonOpts.EmailDomains = []string{"*"}
|
||||
commonOpts.ProviderType = "oidc"
|
||||
commonOpts.ProxyPrefix = "/pbprox"
|
||||
commonOpts.SkipProviderButton = true
|
||||
commonOpts.Logging.SilencePing = true
|
||||
commonOpts.SetXAuthRequest = true
|
||||
commonOpts.SetAuthorization = true
|
||||
return commonOpts
|
||||
}
|
||||
|
||||
func doGlobalSetup(config map[string]interface{}) {
|
||||
switch config[ConfigLogLevel].(string) {
|
||||
case "debug":
|
||||
log.SetLevel(log.DebugLevel)
|
||||
case "info":
|
||||
log.SetLevel(log.InfoLevel)
|
||||
case "warning":
|
||||
log.SetLevel(log.WarnLevel)
|
||||
case "error":
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
default:
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
if config[ConfigErrorReportingEnabled].(bool) {
|
||||
dsn = "https://33cdbcb23f8b436dbe0ee06847410b67@sentry.beryju.org/3"
|
||||
log.Debug("Error reporting enabled")
|
||||
}
|
||||
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: dsn,
|
||||
Environment: config[ConfigErrorReportingEnvironment].(string),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("sentry.Init: %s", err)
|
||||
}
|
||||
|
||||
defer sentry.Flush(2 * time.Second)
|
||||
}
|
||||
|
||||
func getTLSTransport() http.RoundTripper {
|
||||
_, set := os.LookupEnv("PASSBOOK_INSECURE")
|
||||
tlsTransport, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
|
||||
InsecureSkipVerify: set,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tlsTransport
|
||||
}
|
||||
|
||||
// NewAPIController initialise new API Controller instance from URL and API token
|
||||
func NewAPIController(pbURL url.URL, token string) *APIController {
|
||||
transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
|
||||
|
||||
transport.Transport = getTLSTransport()
|
||||
|
||||
// create the transport
|
||||
auth := httptransport.BasicAuth("", token)
|
||||
|
||||
// create the API client, with the transport
|
||||
apiClient := client.New(transport, strfmt.Default)
|
||||
|
||||
// Because we don't know the outpost UUID, we simply do a list and pick the first
|
||||
// The service account this token belongs to should only have access to a single outpost
|
||||
outposts, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
outpost := outposts.Payload.Results[0]
|
||||
doGlobalSetup(outpost.Config.(map[string]interface{}))
|
||||
|
||||
ac := &APIController{
|
||||
client: apiClient,
|
||||
auth: auth,
|
||||
token: token,
|
||||
|
||||
logger: log.WithField("component", "api-controller"),
|
||||
commonOpts: getCommonOptions(),
|
||||
server: NewServer(),
|
||||
|
||||
lastBundleHash: "",
|
||||
}
|
||||
ac.initWS(pbURL, outpost.Pk)
|
||||
return ac
|
||||
}
|
||||
|
||||
func (a *APIController) bundleProviders() ([]*providerBundle, error) {
|
||||
providers, err := a.client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.auth)
|
||||
if err != nil {
|
||||
a.logger.WithError(err).Error("Failed to fetch providers")
|
||||
return nil, err
|
||||
}
|
||||
// Check provider hash to see if anything is changed
|
||||
hasher := sha512.New()
|
||||
bin, _ := providers.Payload.MarshalBinary()
|
||||
hash := hex.EncodeToString(hasher.Sum(bin))
|
||||
if hash == a.lastBundleHash {
|
||||
return nil, nil
|
||||
}
|
||||
a.lastBundleHash = hash
|
||||
|
||||
bundles := make([]*providerBundle, len(providers.Payload.Results))
|
||||
|
||||
for idx, provider := range providers.Payload.Results {
|
||||
externalHost, err := url.Parse(*provider.ExternalHost)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("Failed to parse URL, skipping provider")
|
||||
}
|
||||
bundles[idx] = &providerBundle{
|
||||
a: a,
|
||||
Host: externalHost.Hostname(),
|
||||
}
|
||||
bundles[idx].Build(provider)
|
||||
}
|
||||
return bundles, nil
|
||||
}
|
||||
|
||||
func (a *APIController) updateHTTPServer(bundles []*providerBundle) {
|
||||
newMap := make(map[string]*providerBundle)
|
||||
for _, bundle := range bundles {
|
||||
newMap[bundle.Host] = bundle
|
||||
}
|
||||
a.logger.Debug("Swapped maps")
|
||||
a.server.Handlers = newMap
|
||||
}
|
||||
|
||||
// UpdateIfRequired Updates the HTTP Server config if required, automatically swaps the handlers
|
||||
func (a *APIController) UpdateIfRequired() error {
|
||||
bundles, err := a.bundleProviders()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bundles == nil {
|
||||
a.logger.Debug("Providers have not changed, not updating")
|
||||
return nil
|
||||
}
|
||||
a.updateHTTPServer(bundles)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start Starts all handlers, non-blocking
|
||||
func (a *APIController) Start() error {
|
||||
err := a.UpdateIfRequired()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
a.logger.Debug("Starting HTTP Server...")
|
||||
a.server.ServeHTTP()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting HTTPs Server...")
|
||||
a.server.ServeHTTPS()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS Handler...")
|
||||
a.startWSHandler()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS Health notifier...")
|
||||
a.startWSHealth()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
123
proxy/pkg/server/api_bundle.go
Normal file
123
proxy/pkg/server/api_bundle.go
Normal file
@ -0,0 +1,123 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client/crypto"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/models"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/proxy"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type providerBundle struct {
|
||||
http.Handler
|
||||
|
||||
a *APIController
|
||||
proxy *proxy.OAuthProxy
|
||||
Host string
|
||||
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *options.Options {
|
||||
externalHost, err := url.Parse(*provider.ExternalHost)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("Failed to parse URL, skipping provider")
|
||||
return nil
|
||||
}
|
||||
providerOpts := &options.Options{}
|
||||
copier.Copy(&providerOpts, &pb.a.commonOpts)
|
||||
providerOpts.ClientID = provider.ClientID
|
||||
providerOpts.ClientSecret = provider.ClientSecret
|
||||
|
||||
providerOpts.Cookie.Secret = provider.CookieSecret
|
||||
providerOpts.Cookie.Secure = externalHost.Scheme == "https"
|
||||
|
||||
providerOpts.SkipOIDCDiscovery = true
|
||||
providerOpts.OIDCIssuerURL = *provider.OidcConfiguration.Issuer
|
||||
providerOpts.LoginURL = *provider.OidcConfiguration.AuthorizationEndpoint
|
||||
providerOpts.RedeemURL = *provider.OidcConfiguration.TokenEndpoint
|
||||
providerOpts.OIDCJwksURL = *provider.OidcConfiguration.JwksURI
|
||||
providerOpts.ProfileURL = *provider.OidcConfiguration.UserinfoEndpoint
|
||||
|
||||
providerOpts.UpstreamServers = []options.Upstream{
|
||||
{
|
||||
ID: "default",
|
||||
URI: *provider.InternalHost,
|
||||
Path: "/",
|
||||
},
|
||||
}
|
||||
|
||||
if provider.Certificate != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).Debug("Enabling TLS")
|
||||
cert, err := pb.a.client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
|
||||
Context: context.Background(),
|
||||
KpUUID: *provider.Certificate,
|
||||
}, pb.a.auth)
|
||||
if err != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
|
||||
return providerOpts
|
||||
}
|
||||
x509cert, err := tls.X509KeyPair([]byte(*cert.Payload.CertificateData), []byte(cert.Payload.KeyData))
|
||||
if err != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
|
||||
return providerOpts
|
||||
}
|
||||
pb.cert = &x509cert
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
|
||||
}
|
||||
return providerOpts
|
||||
}
|
||||
|
||||
func (pb *providerBundle) Build(provider *models.ProxyOutpostConfig) {
|
||||
opts := pb.prepareOpts(provider)
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
||||
healthCheckPaths := []string{opts.PingPath}
|
||||
healthCheckUserAgents := []string{opts.PingUserAgent}
|
||||
if opts.GCPHealthChecks {
|
||||
healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
|
||||
healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
|
||||
}
|
||||
|
||||
// To silence logging of health checks, register the health check handler before
|
||||
// the logging handler
|
||||
if opts.Logging.SilencePing {
|
||||
chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
|
||||
} else {
|
||||
chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
|
||||
}
|
||||
|
||||
err := validation.Validate(opts)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
oauthproxy, err := proxy.NewOAuthProxy(opts)
|
||||
if err != nil {
|
||||
log.Errorf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pb.proxy = oauthproxy
|
||||
pb.Handler = chain.Then(oauthproxy)
|
||||
}
|
||||
85
proxy/pkg/server/api_ws.go
Normal file
85
proxy/pkg/server/api_ws.go
Normal file
@ -0,0 +1,85 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/recws-org/recws"
|
||||
)
|
||||
|
||||
func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
|
||||
pathTemplate := "%s://%s/ws/outpost/%s/"
|
||||
scheme := strings.ReplaceAll(pbURL.Scheme, "http", "ws")
|
||||
|
||||
header := http.Header{
|
||||
"Authorization": []string{ac.token},
|
||||
}
|
||||
|
||||
_, set := os.LookupEnv("PASSBOOK_INSECURE")
|
||||
|
||||
ws := recws.RecConn{
|
||||
// KeepAliveTimeout: 10 * time.Second,
|
||||
NonVerbose: true,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: set,
|
||||
},
|
||||
}
|
||||
ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header)
|
||||
|
||||
ac.logger.WithField("outpost", outpostUUID.String()).Debug("connecting to passbook")
|
||||
|
||||
ac.wsConn = ws
|
||||
}
|
||||
|
||||
// Shutdown Gracefully stops all workers, disconnects from websocket
|
||||
func (ac *APIController) Shutdown() {
|
||||
// Cleanly close the connection by sending a close message and then
|
||||
// waiting (with timeout) for the server to close the connection.
|
||||
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
ac.logger.Println("write close:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ac *APIController) startWSHandler() {
|
||||
for {
|
||||
var wsMsg websocketMessage
|
||||
err := ac.wsConn.ReadJSON(&wsMsg)
|
||||
if err != nil {
|
||||
ac.logger.Println("read:", err)
|
||||
return
|
||||
}
|
||||
if wsMsg.Instruction != WebsocketInstructionAck {
|
||||
ac.logger.Debugf("%+v\n", wsMsg)
|
||||
}
|
||||
if wsMsg.Instruction == WebsocketInstructionTriggerUpdate {
|
||||
err := ac.UpdateIfRequired()
|
||||
if err != nil {
|
||||
ac.logger.WithError(err).Debug("Failed to update")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ac *APIController) startWSHealth() {
|
||||
for ; true; <-time.Tick(time.Second * 10) {
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: make(map[string]interface{}),
|
||||
}
|
||||
err := ac.wsConn.WriteJSON(aliveMsg)
|
||||
if err != nil {
|
||||
ac.logger.Println("write:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
17
proxy/pkg/server/api_ws_msg.go
Normal file
17
proxy/pkg/server/api_ws_msg.go
Normal file
@ -0,0 +1,17 @@
|
||||
package server
|
||||
|
||||
type websocketInstruction int
|
||||
|
||||
const (
|
||||
// WebsocketInstructionAck Code used to acknowledge a previous message
|
||||
WebsocketInstructionAck websocketInstruction = 0
|
||||
// WebsocketInstructionHello Code used to send a healthcheck keepalive
|
||||
WebsocketInstructionHello websocketInstruction = 1
|
||||
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
||||
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
||||
)
|
||||
|
||||
type websocketMessage struct {
|
||||
Instruction websocketInstruction `json:"instruction"`
|
||||
Args map[string]interface{} `json:"args"`
|
||||
}
|
||||
63
proxy/pkg/server/cert.go
Normal file
63
proxy/pkg/server/cert.go
Normal file
@ -0,0 +1,63 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert() (tls.Certificate, error) {
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate private key: %v", err)
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate serial number: %v", err)
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"passbook"},
|
||||
CommonName: "passbook Proxy default certificate",
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
template.DNSNames = []string{"*"}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
privPemByes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
|
||||
return tls.X509KeyPair(pemBytes, privPemByes)
|
||||
}
|
||||
123
proxy/pkg/server/middleware.go
Normal file
123
proxy/pkg/server/middleware.go
Normal file
@ -0,0 +1,123 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||
// code and body size
|
||||
type responseLogger struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
upstream string
|
||||
authInfo string
|
||||
}
|
||||
|
||||
// Header returns the ResponseWriter's Header
|
||||
func (l *responseLogger) Header() http.Header {
|
||||
return l.w.Header()
|
||||
}
|
||||
|
||||
// Support Websocket
|
||||
func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
if hj, ok := l.w.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, errors.New("http.Hijacker is not available on writer")
|
||||
}
|
||||
|
||||
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
|
||||
// Header
|
||||
func (l *responseLogger) ExtractGAPMetadata() {
|
||||
upstream := l.w.Header().Get("GAP-Upstream-Address")
|
||||
if upstream != "" {
|
||||
l.upstream = upstream
|
||||
l.w.Header().Del("GAP-Upstream-Address")
|
||||
}
|
||||
authInfo := l.w.Header().Get("GAP-Auth")
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
l.w.Header().Del("GAP-Auth")
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes the response using the ResponseWriter
|
||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||
if l.status == 0 {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
l.status = http.StatusOK
|
||||
}
|
||||
l.ExtractGAPMetadata()
|
||||
size, err := l.w.Write(b)
|
||||
l.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
// WriteHeader writes the status code for the Response
|
||||
func (l *responseLogger) WriteHeader(s int) {
|
||||
l.ExtractGAPMetadata()
|
||||
l.w.WriteHeader(s)
|
||||
l.status = s
|
||||
}
|
||||
|
||||
// Status returns the response status code
|
||||
func (l *responseLogger) Status() int {
|
||||
return l.status
|
||||
}
|
||||
|
||||
// Size returns the response size
|
||||
func (l *responseLogger) Size() int {
|
||||
return l.size
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client
|
||||
func (l *responseLogger) Flush() {
|
||||
if flusher, ok := l.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// loggingHandler is the http.Handler implementation for LoggingHandler
|
||||
type loggingHandler struct {
|
||||
handler http.Handler
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
|
||||
func LoggingHandler(h http.Handler) http.Handler {
|
||||
return loggingHandler{
|
||||
handler: h,
|
||||
logger: log.WithField("component", "http-server"),
|
||||
}
|
||||
}
|
||||
|
||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
t := time.Now()
|
||||
url := *req.URL
|
||||
responseLogger := &responseLogger{w: w}
|
||||
h.handler.ServeHTTP(responseLogger, req)
|
||||
duration := float64(time.Since(t)) / float64(time.Second)
|
||||
h.logger.WithFields(log.Fields{
|
||||
"Client": req.RemoteAddr,
|
||||
"Host": req.Host,
|
||||
"Protocol": req.Proto,
|
||||
"RequestDuration": fmt.Sprintf("%0.3f", duration),
|
||||
"RequestMethod": req.Method,
|
||||
"ResponseSize": responseLogger.Size(),
|
||||
"StatusCode": responseLogger.Status(),
|
||||
"Timestamp": logger.FormatTimestamp(t),
|
||||
"Upstream": responseLogger.upstream,
|
||||
"UserAgent": req.UserAgent(),
|
||||
"Username": responseLogger.authInfo,
|
||||
}).Info(url.RequestURI())
|
||||
// logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, , )
|
||||
}
|
||||
152
proxy/pkg/server/server.go
Normal file
152
proxy/pkg/server/server.go
Normal file
@ -0,0 +1,152 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||
)
|
||||
|
||||
// Server represents an HTTP server
|
||||
type Server struct {
|
||||
Handlers map[string]*providerBundle
|
||||
|
||||
stop chan struct{} // channel for waiting shutdown
|
||||
logger *log.Entry
|
||||
|
||||
defaultCert tls.Certificate
|
||||
}
|
||||
|
||||
// NewServer initialise a new HTTP Server
|
||||
func NewServer() *Server {
|
||||
defaultCert, err := generateSelfSignedCert()
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
return &Server{
|
||||
Handlers: make(map[string]*providerBundle),
|
||||
logger: log.WithField("component", "http-server"),
|
||||
defaultCert: defaultCert,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
||||
func (s *Server) ServeHTTP() {
|
||||
// TODO: make this a setting
|
||||
listenAddress := "localhost:4180"
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
|
||||
}
|
||||
s.logger.Printf("listening on %s", listener.Addr())
|
||||
s.serve(listener)
|
||||
s.logger.Printf("closing %s", listener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
handler, ok := s.Handlers[info.ServerName]
|
||||
if !ok {
|
||||
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
|
||||
return &s.defaultCert, nil
|
||||
}
|
||||
if handler.cert == nil {
|
||||
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
|
||||
return &s.defaultCert, nil
|
||||
}
|
||||
return handler.cert, nil
|
||||
}
|
||||
|
||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
||||
func (s *Server) ServeHTTPS() {
|
||||
// TODO: make this a setting
|
||||
listenAddress := "localhost:4443"
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
GetCertificate: s.getCertificates,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
|
||||
}
|
||||
s.logger.Printf("listening on %s", ln.Addr())
|
||||
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
||||
s.serve(tlsListener)
|
||||
s.logger.Printf("closing %s", tlsListener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||
handler, ok := s.Handlers[r.Host]
|
||||
if !ok {
|
||||
// If we only have one handler, host name switching doesn't matter
|
||||
if len(s.Handlers) == 1 {
|
||||
for k := range s.Handlers {
|
||||
s.Handlers[k].ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.logger.WithField("host", r.Host).Debug("Host header does not match any we know of")
|
||||
s.logger.Printf("%v+\n", s.Handlers)
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
s.logger.WithField("host", r.Host).Debug("passing request from host head")
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) serve(listener net.Listener) {
|
||||
sentryHandler := sentryhttp.New(sentryhttp.Options{})
|
||||
|
||||
srv := &http.Server{Handler: sentryHandler.HandleFunc(s.handler)}
|
||||
|
||||
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
||||
idleConnsClosed := make(chan struct{})
|
||||
go func() {
|
||||
<-s.stop // wait notification for stopping server
|
||||
|
||||
// We received an interrupt signal, shut down.
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
// Error from closing listeners, or context timeout:
|
||||
s.logger.Printf("HTTP server Shutdown: %v", err)
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
|
||||
err := srv.Serve(listener)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.logger.Errorf("ERROR: http.Serve() - %s", err)
|
||||
}
|
||||
<-idleConnsClosed
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
log.Printf("Error setting Keep-Alive: %v", err)
|
||||
}
|
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
if err != nil {
|
||||
log.Printf("Error setting Keep-Alive period: %v", err)
|
||||
}
|
||||
return tc, nil
|
||||
}
|
||||
Reference in New Issue
Block a user