outpost: separate ak-api and proxy further for future outposts
This commit is contained in:
		
							
								
								
									
										48
									
								
								outpost/pkg/proxy/api.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								outpost/pkg/proxy/api.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,48 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/url"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/models"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (s *Server) Refresh() error {
 | 
			
		||||
	providers, err := s.ak.Update()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if providers == nil {
 | 
			
		||||
		s.logger.Debug("Providers have not changed, not updating")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	bundles := s.bundleProviders(providers)
 | 
			
		||||
	s.updateHTTPServer(bundles)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) bundleProviders(providers []*models.ProxyOutpostConfig) []*providerBundle {
 | 
			
		||||
	bundles := make([]*providerBundle, len(providers))
 | 
			
		||||
	for idx, provider := range providers {
 | 
			
		||||
		externalHost, err := url.Parse(*provider.ExternalHost)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.WithError(err).Warning("Failed to parse URL, skipping provider")
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx] = &providerBundle{
 | 
			
		||||
			s:    s,
 | 
			
		||||
			Host: externalHost.Host,
 | 
			
		||||
			log:  log.WithField("component", "proxy-bundle").WithField("provider", provider.Name),
 | 
			
		||||
		}
 | 
			
		||||
		bundles[idx].Build(provider)
 | 
			
		||||
	}
 | 
			
		||||
	return bundles
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) updateHTTPServer(bundles []*providerBundle) {
 | 
			
		||||
	newMap := make(map[string]*providerBundle)
 | 
			
		||||
	for _, bundle := range bundles {
 | 
			
		||||
		newMap[bundle.Host] = bundle
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Debug("Swapped maps")
 | 
			
		||||
	s.Handlers = newMap
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										137
									
								
								outpost/pkg/proxy/api_bundle.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								outpost/pkg/proxy/api_bundle.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,137 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/client/crypto"
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/models"
 | 
			
		||||
	"github.com/jinzhu/copier"
 | 
			
		||||
	"github.com/justinas/alice"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type providerBundle struct {
 | 
			
		||||
	http.Handler
 | 
			
		||||
 | 
			
		||||
	s     *Server
 | 
			
		||||
	proxy *OAuthProxy
 | 
			
		||||
	Host  string
 | 
			
		||||
 | 
			
		||||
	cert *tls.Certificate
 | 
			
		||||
 | 
			
		||||
	log *log.Entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *options.Options {
 | 
			
		||||
	externalHost, err := url.Parse(*provider.ExternalHost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Warning("Failed to parse URL, skipping provider")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	providerOpts := &options.Options{}
 | 
			
		||||
	copier.Copy(&providerOpts, getCommonOptions())
 | 
			
		||||
	providerOpts.ClientID = provider.ClientID
 | 
			
		||||
	providerOpts.ClientSecret = provider.ClientSecret
 | 
			
		||||
 | 
			
		||||
	providerOpts.Cookie.Secret = provider.CookieSecret
 | 
			
		||||
	providerOpts.Cookie.Secure = externalHost.Scheme == "https"
 | 
			
		||||
 | 
			
		||||
	providerOpts.SkipOIDCDiscovery = true
 | 
			
		||||
	providerOpts.OIDCIssuerURL = *provider.OidcConfiguration.Issuer
 | 
			
		||||
	providerOpts.LoginURL = *provider.OidcConfiguration.AuthorizationEndpoint
 | 
			
		||||
	providerOpts.RedeemURL = *provider.OidcConfiguration.TokenEndpoint
 | 
			
		||||
	providerOpts.OIDCJwksURL = *provider.OidcConfiguration.JwksURI
 | 
			
		||||
	providerOpts.ProfileURL = *provider.OidcConfiguration.UserinfoEndpoint
 | 
			
		||||
 | 
			
		||||
	if provider.SkipPathRegex != "" {
 | 
			
		||||
		skipRegexes := strings.Split(provider.SkipPathRegex, "\n")
 | 
			
		||||
		providerOpts.SkipAuthRegex = skipRegexes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	providerOpts.UpstreamServers = []options.Upstream{
 | 
			
		||||
		{
 | 
			
		||||
			ID:                    "default",
 | 
			
		||||
			URI:                   *provider.InternalHost,
 | 
			
		||||
			Path:                  "/",
 | 
			
		||||
			InsecureSkipTLSVerify: *&provider.InternalHostSslValidation,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if provider.Certificate != nil {
 | 
			
		||||
		pb.log.WithField("provider", provider.ClientID).Debug("Enabling TLS")
 | 
			
		||||
		cert, err := pb.s.ak.Client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
 | 
			
		||||
			Context: context.Background(),
 | 
			
		||||
			KpUUID:  *provider.Certificate,
 | 
			
		||||
		}, pb.s.ak.Auth)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			pb.log.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
 | 
			
		||||
			return providerOpts
 | 
			
		||||
		}
 | 
			
		||||
		x509cert, err := tls.X509KeyPair([]byte(*cert.Payload.CertificateData), []byte(cert.Payload.KeyData))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			pb.log.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
 | 
			
		||||
			return providerOpts
 | 
			
		||||
		}
 | 
			
		||||
		pb.cert = &x509cert
 | 
			
		||||
		pb.log.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
 | 
			
		||||
	}
 | 
			
		||||
	return providerOpts
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pb *providerBundle) Build(provider *models.ProxyOutpostConfig) {
 | 
			
		||||
	opts := pb.prepareOpts(provider)
 | 
			
		||||
 | 
			
		||||
	chain := alice.New()
 | 
			
		||||
 | 
			
		||||
	if opts.ForceHTTPS {
 | 
			
		||||
		_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
 | 
			
		||||
		}
 | 
			
		||||
		chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	healthCheckPaths := []string{opts.PingPath}
 | 
			
		||||
	healthCheckUserAgents := []string{opts.PingUserAgent}
 | 
			
		||||
	if opts.GCPHealthChecks {
 | 
			
		||||
		healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
 | 
			
		||||
		healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// To silence logging of health checks, register the health check handler before
 | 
			
		||||
	// the logging handler
 | 
			
		||||
	if opts.Logging.SilencePing {
 | 
			
		||||
		chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
 | 
			
		||||
	} else {
 | 
			
		||||
		chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := validation.Validate(opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Printf("%s", err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
	oauthproxy, err := NewOAuthProxy(opts)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Errorf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
 | 
			
		||||
		os.Exit(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if *&provider.BasicAuthEnabled {
 | 
			
		||||
		oauthproxy.SetBasicAuth = true
 | 
			
		||||
		oauthproxy.BasicAuthUserAttribute = provider.BasicAuthUserAttribute
 | 
			
		||||
		oauthproxy.BasicAuthPasswordAttribute = provider.BasicAuthPasswordAttribute
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pb.proxy = oauthproxy
 | 
			
		||||
	pb.Handler = chain.Then(oauthproxy)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										20
									
								
								outpost/pkg/proxy/common.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								outpost/pkg/proxy/common.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,20 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getCommonOptions() *options.Options {
 | 
			
		||||
	commonOpts := options.NewOptions()
 | 
			
		||||
	commonOpts.Cookie.Name = "authentik_proxy"
 | 
			
		||||
	commonOpts.Cookie.Expire = 24 * time.Hour
 | 
			
		||||
	commonOpts.EmailDomains = []string{"*"}
 | 
			
		||||
	commonOpts.ProviderType = "oidc"
 | 
			
		||||
	commonOpts.ProxyPrefix = "/akprox"
 | 
			
		||||
	commonOpts.Logging.SilencePing = true
 | 
			
		||||
	commonOpts.SetAuthorization = false
 | 
			
		||||
	commonOpts.Scope = "openid email profile ak_proxy"
 | 
			
		||||
	return commonOpts
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										122
									
								
								outpost/pkg/proxy/middleware.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										122
									
								
								outpost/pkg/proxy/middleware.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,122 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
 | 
			
		||||
// code and body size
 | 
			
		||||
type responseLogger struct {
 | 
			
		||||
	w        http.ResponseWriter
 | 
			
		||||
	status   int
 | 
			
		||||
	size     int
 | 
			
		||||
	upstream string
 | 
			
		||||
	authInfo string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Header returns the ResponseWriter's Header
 | 
			
		||||
func (l *responseLogger) Header() http.Header {
 | 
			
		||||
	return l.w.Header()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Support Websocket
 | 
			
		||||
func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
 | 
			
		||||
	if hj, ok := l.w.(http.Hijacker); ok {
 | 
			
		||||
		return hj.Hijack()
 | 
			
		||||
	}
 | 
			
		||||
	return nil, nil, errors.New("http.Hijacker is not available on writer")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
 | 
			
		||||
// Header
 | 
			
		||||
func (l *responseLogger) ExtractGAPMetadata() {
 | 
			
		||||
	upstream := l.w.Header().Get("GAP-Upstream-Address")
 | 
			
		||||
	if upstream != "" {
 | 
			
		||||
		l.upstream = upstream
 | 
			
		||||
		l.w.Header().Del("GAP-Upstream-Address")
 | 
			
		||||
	}
 | 
			
		||||
	authInfo := l.w.Header().Get("GAP-Auth")
 | 
			
		||||
	if authInfo != "" {
 | 
			
		||||
		l.authInfo = authInfo
 | 
			
		||||
		l.w.Header().Del("GAP-Auth")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write writes the response using the ResponseWriter
 | 
			
		||||
func (l *responseLogger) Write(b []byte) (int, error) {
 | 
			
		||||
	if l.status == 0 {
 | 
			
		||||
		// The status will be StatusOK if WriteHeader has not been called yet
 | 
			
		||||
		l.status = http.StatusOK
 | 
			
		||||
	}
 | 
			
		||||
	l.ExtractGAPMetadata()
 | 
			
		||||
	size, err := l.w.Write(b)
 | 
			
		||||
	l.size += size
 | 
			
		||||
	return size, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WriteHeader writes the status code for the Response
 | 
			
		||||
func (l *responseLogger) WriteHeader(s int) {
 | 
			
		||||
	l.ExtractGAPMetadata()
 | 
			
		||||
	l.w.WriteHeader(s)
 | 
			
		||||
	l.status = s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Status returns the response status code
 | 
			
		||||
func (l *responseLogger) Status() int {
 | 
			
		||||
	return l.status
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Size returns the response size
 | 
			
		||||
func (l *responseLogger) Size() int {
 | 
			
		||||
	return l.size
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Flush sends any buffered data to the client
 | 
			
		||||
func (l *responseLogger) Flush() {
 | 
			
		||||
	if flusher, ok := l.w.(http.Flusher); ok {
 | 
			
		||||
		flusher.Flush()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// loggingHandler is the http.Handler implementation for LoggingHandler
 | 
			
		||||
type loggingHandler struct {
 | 
			
		||||
	handler http.Handler
 | 
			
		||||
	logger  *log.Entry
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
 | 
			
		||||
func LoggingHandler(h http.Handler) http.Handler {
 | 
			
		||||
	return loggingHandler{
 | 
			
		||||
		handler: h,
 | 
			
		||||
		logger:  log.WithField("component", "proxy-http-server"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	url := *req.URL
 | 
			
		||||
	responseLogger := &responseLogger{w: w}
 | 
			
		||||
	h.handler.ServeHTTP(responseLogger, req)
 | 
			
		||||
	duration := float64(time.Since(t)) / float64(time.Second)
 | 
			
		||||
	h.logger.WithFields(log.Fields{
 | 
			
		||||
		"Client":          req.RemoteAddr,
 | 
			
		||||
		"Host":            req.Host,
 | 
			
		||||
		"Protocol":        req.Proto,
 | 
			
		||||
		"RequestDuration": fmt.Sprintf("%0.3f", duration),
 | 
			
		||||
		"RequestMethod":   req.Method,
 | 
			
		||||
		"ResponseSize":    responseLogger.Size(),
 | 
			
		||||
		"StatusCode":      responseLogger.Status(),
 | 
			
		||||
		"Timestamp":       t,
 | 
			
		||||
		"Upstream":        responseLogger.upstream,
 | 
			
		||||
		"UserAgent":       req.UserAgent(),
 | 
			
		||||
		"Username":        responseLogger.authInfo,
 | 
			
		||||
	}).Info(url.RequestURI())
 | 
			
		||||
	// logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, , )
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										107
									
								
								outpost/pkg/proxy/server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								outpost/pkg/proxy/server.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,107 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/BeryJu/authentik/outpost/pkg/ak"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Server represents an HTTP server
 | 
			
		||||
type Server struct {
 | 
			
		||||
	Handlers map[string]*providerBundle
 | 
			
		||||
 | 
			
		||||
	stop        chan struct{} // channel for waiting shutdown
 | 
			
		||||
	logger      *log.Entry
 | 
			
		||||
	ak          *ak.APIController
 | 
			
		||||
	defaultCert tls.Certificate
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewServer initialise a new HTTP Server
 | 
			
		||||
func NewServer(ac *ak.APIController) *Server {
 | 
			
		||||
	defaultCert, err := ak.GenerateSelfSignedCert()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Warning(err)
 | 
			
		||||
	}
 | 
			
		||||
	return &Server{
 | 
			
		||||
		Handlers:    make(map[string]*providerBundle),
 | 
			
		||||
		logger:      log.WithField("component", "proxy-http-server"),
 | 
			
		||||
		defaultCert: defaultCert,
 | 
			
		||||
		ak:          ac,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	if r.URL.Path == "/akprox/ping" {
 | 
			
		||||
		w.WriteHeader(204)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	handler, ok := s.Handlers[r.Host]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		// If we only have one handler, host name switching doesn't matter
 | 
			
		||||
		if len(s.Handlers) == 1 {
 | 
			
		||||
			for k := range s.Handlers {
 | 
			
		||||
				s.Handlers[k].ServeHTTP(w, r)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		s.logger.WithField("host", r.Host).Debug("Host header does not match any we know of")
 | 
			
		||||
		s.logger.Printf("%v+\n", s.Handlers)
 | 
			
		||||
		w.WriteHeader(400)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.WithField("host", r.Host).Debug("passing request from host head")
 | 
			
		||||
	handler.ServeHTTP(w, r)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) serve(listener net.Listener) {
 | 
			
		||||
	srv := &http.Server{Handler: http.HandlerFunc(s.handler)}
 | 
			
		||||
 | 
			
		||||
	// See https://golang.org/pkg/net/http/#Server.Shutdown
 | 
			
		||||
	idleConnsClosed := make(chan struct{})
 | 
			
		||||
	go func() {
 | 
			
		||||
		<-s.stop // wait notification for stopping server
 | 
			
		||||
 | 
			
		||||
		// We received an interrupt signal, shut down.
 | 
			
		||||
		if err := srv.Shutdown(context.Background()); err != nil {
 | 
			
		||||
			// Error from closing listeners, or context timeout:
 | 
			
		||||
			s.logger.Printf("HTTP server Shutdown: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		close(idleConnsClosed)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	err := srv.Serve(listener)
 | 
			
		||||
	if err != nil && !errors.Is(err, http.ErrServerClosed) {
 | 
			
		||||
		s.logger.Errorf("ERROR: http.Serve() - %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	<-idleConnsClosed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
 | 
			
		||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
 | 
			
		||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
 | 
			
		||||
// go away.
 | 
			
		||||
type tcpKeepAliveListener struct {
 | 
			
		||||
	*net.TCPListener
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
 | 
			
		||||
	tc, err := ln.AcceptTCP()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	err = tc.SetKeepAlive(true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Printf("Error setting Keep-Alive: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	err = tc.SetKeepAlivePeriod(3 * time.Minute)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Printf("Error setting Keep-Alive period: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return tc, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										68
									
								
								outpost/pkg/proxy/server_https.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								outpost/pkg/proxy/server_https.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,68 @@
 | 
			
		||||
package proxy
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
 | 
			
		||||
func (s *Server) ServeHTTP() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4180"
 | 
			
		||||
	listener, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", listener.Addr())
 | 
			
		||||
	s.serve(listener)
 | 
			
		||||
	s.logger.Printf("closing %s", listener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
	handler, ok := s.Handlers[info.ServerName]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	if handler.cert == nil {
 | 
			
		||||
		s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
 | 
			
		||||
		return &s.defaultCert, nil
 | 
			
		||||
	}
 | 
			
		||||
	return handler.cert, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
 | 
			
		||||
func (s *Server) ServeHTTPS() {
 | 
			
		||||
	listenAddress := "0.0.0.0:4443"
 | 
			
		||||
	config := &tls.Config{
 | 
			
		||||
		MinVersion:     tls.VersionTLS12,
 | 
			
		||||
		MaxVersion:     tls.VersionTLS12,
 | 
			
		||||
		GetCertificate: s.getCertificates,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ln, err := net.Listen("tcp", listenAddress)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
 | 
			
		||||
	}
 | 
			
		||||
	s.logger.Printf("listening on %s", ln.Addr())
 | 
			
		||||
 | 
			
		||||
	tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
 | 
			
		||||
	s.serve(tlsListener)
 | 
			
		||||
	s.logger.Printf("closing %s", tlsListener.Addr())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) Start() error {
 | 
			
		||||
	wg := sync.WaitGroup{}
 | 
			
		||||
	wg.Add(2)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		s.logger.Debug("Starting HTTP Server...")
 | 
			
		||||
		s.ServeHTTP()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		s.logger.Debug("Starting HTTPs Server...")
 | 
			
		||||
		s.ServeHTTPS()
 | 
			
		||||
	}()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user