Proxy v2 (#189)
This commit is contained in:
2
proxy/.dockerignore
Normal file
2
proxy/.dockerignore
Normal file
@ -0,0 +1,2 @@
|
||||
Dockerfile.*
|
||||
.git
|
||||
2
proxy/.gitignore
vendored
Normal file
2
proxy/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
pkg/client/
|
||||
pkg/models/
|
||||
12
proxy/Dockerfile
Normal file
12
proxy/Dockerfile
Normal file
@ -0,0 +1,12 @@
|
||||
FROM golang:1.15 AS builder
|
||||
|
||||
WORKDIR /work
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN go build -o /work/proxy .
|
||||
|
||||
# Copy binary to alpine
|
||||
FROM gcr.io/distroless/base-debian10
|
||||
COPY --from=builder /work/proxy /
|
||||
ENTRYPOINT ["/proxy"]
|
||||
6
proxy/Makefile
Normal file
6
proxy/Makefile
Normal file
@ -0,0 +1,6 @@
|
||||
generate:
|
||||
go get -u github.com/go-swagger/go-swagger/cmd/swagger
|
||||
swagger generate client -f ../swagger.yaml -A passbook -t pkg/
|
||||
|
||||
run:
|
||||
go run -v .
|
||||
24
proxy/README.md
Normal file
24
proxy/README.md
Normal file
@ -0,0 +1,24 @@
|
||||
# passbook Proxy
|
||||
|
||||
[](https://dev.azure.com/beryjuorg/passbook/_build?definitionId=3)
|
||||

|
||||
|
||||
Reverse Proxy based on [oauth2_proxy](https://github.com/oauth2-proxy/oauth2-proxy), completely managed and monitored by passbook.
|
||||
|
||||
## Usage
|
||||
|
||||
passbook Proxy is built to be configured by passbook itself, hence the only options you can directly give it are connection params.
|
||||
|
||||
The following environment variable are implemented:
|
||||
|
||||
`PASSBOOK_HOST`: Full URL to the passbook instance with protocol, i.e. "https://passbook.company.tld"
|
||||
|
||||
`PASSBOOK_TOKEN`: Token used to authenticate against passbook. This is generated after an Outpost instance is created.
|
||||
|
||||
`PASSBOOK_INSECURE`: This environment variable can optionally be set to ignore the SSL Certificate of the passbook instance. Applies to both HTTP and WS connections.
|
||||
|
||||
## Development
|
||||
|
||||
passbook Proxy uses an auto-generated API Client to communicate with passbook. This client is not kept in git. To generate the client locally, run `make generate`.
|
||||
|
||||
Afterwards you can build the proxy like any other Go project, using `go build`.
|
||||
90
proxy/azure-pipelines.yml
Normal file
90
proxy/azure-pipelines.yml
Normal file
@ -0,0 +1,90 @@
|
||||
trigger:
|
||||
- master
|
||||
|
||||
stages:
|
||||
- stage: generate
|
||||
jobs:
|
||||
- job: swagger_generate
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
steps:
|
||||
- task: GoTool@0
|
||||
inputs:
|
||||
version: '1.15'
|
||||
- task: Go@0
|
||||
inputs:
|
||||
command: 'get'
|
||||
arguments: '-u github.com/go-swagger/go-swagger/cmd/swagger'
|
||||
- task: CmdLine@2
|
||||
inputs:
|
||||
script: |
|
||||
$(go list -f {{.Target}} github.com/go-swagger/go-swagger/cmd/swagger) generate client -f ../swagger.yaml -A passbook -t pkg/
|
||||
workingDirectory: 'proxy/'
|
||||
- task: PublishPipelineArtifact@1
|
||||
inputs:
|
||||
targetPath: 'proxy/pkg/'
|
||||
artifact: 'swagger_client'
|
||||
publishLocation: 'pipeline'
|
||||
- stage: lint
|
||||
jobs:
|
||||
- job: golint
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
steps:
|
||||
- task: GoTool@0
|
||||
inputs:
|
||||
version: '1.15'
|
||||
- task: Go@0
|
||||
inputs:
|
||||
command: 'get'
|
||||
arguments: '-u golang.org/x/lint/golint'
|
||||
- task: DownloadPipelineArtifact@2
|
||||
inputs:
|
||||
buildType: 'current'
|
||||
artifactName: 'swagger_client'
|
||||
path: "proxy/pkg/"
|
||||
- task: CmdLine@2
|
||||
inputs:
|
||||
script: |
|
||||
$(go list -f {{.Target}} golang.org/x/lint/golint) ./...
|
||||
workingDirectory: 'proxy/'
|
||||
- stage: build_go
|
||||
jobs:
|
||||
- job: build_go
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
steps:
|
||||
- task: GoTool@0
|
||||
inputs:
|
||||
version: '1.15'
|
||||
- task: DownloadPipelineArtifact@2
|
||||
inputs:
|
||||
buildType: 'current'
|
||||
artifactName: 'swagger_client'
|
||||
path: "proxy/pkg/"
|
||||
- task: Go@0
|
||||
inputs:
|
||||
command: 'build'
|
||||
workingDirectory: 'proxy/'
|
||||
- stage: build_docker
|
||||
jobs:
|
||||
- job: build_proxy
|
||||
pool:
|
||||
vmImage: 'ubuntu-latest'
|
||||
steps:
|
||||
- task: GoTool@0
|
||||
inputs:
|
||||
version: '1.15'
|
||||
- task: DownloadPipelineArtifact@2
|
||||
inputs:
|
||||
buildType: 'current'
|
||||
artifactName: 'swagger_client'
|
||||
path: "proxy/pkg/"
|
||||
- task: Docker@2
|
||||
inputs:
|
||||
containerRegistry: 'dockerhub'
|
||||
repository: 'beryju/passbook-proxy'
|
||||
command: 'buildAndPush'
|
||||
Dockerfile: 'proxy/Dockerfile'
|
||||
buildContext: 'proxy/'
|
||||
tags: 'gh-$(Build.SourceBranchName)'
|
||||
45
proxy/cmd/server.go
Normal file
45
proxy/cmd/server.go
Normal file
@ -0,0 +1,45 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/BeryJu/passbook/proxy/pkg/server"
|
||||
)
|
||||
|
||||
// RunServer main entrypoint, runs the full server
|
||||
func RunServer() {
|
||||
pbURL, found := os.LookupEnv("PASSBOOK_HOST")
|
||||
if !found {
|
||||
panic("env PASSBOOK_HOST not set!")
|
||||
}
|
||||
pbToken, found := os.LookupEnv("PASSBOOK_TOKEN")
|
||||
if !found {
|
||||
panic("env PASSBOOK_TOKEN not set!")
|
||||
}
|
||||
|
||||
pbURLActual, err := url.Parse(pbURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
ac := server.NewAPIController(*pbURLActual, pbToken)
|
||||
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, os.Interrupt)
|
||||
|
||||
ac.Start()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-interrupt:
|
||||
ac.Shutdown()
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
proxy/go.mod
Normal file
44
proxy/go.mod
Normal file
@ -0,0 +1,44 @@
|
||||
module github.com/BeryJu/passbook/proxy
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.64.0 // indirect
|
||||
github.com/asaskevich/govalidator v0.0.0-20200819183940-29e1ff8eb0bb // indirect
|
||||
github.com/coreos/go-oidc v2.2.1+incompatible
|
||||
github.com/getsentry/sentry-go v0.7.0
|
||||
github.com/go-openapi/errors v0.19.6
|
||||
github.com/go-openapi/runtime v0.19.21
|
||||
github.com/go-openapi/spec v0.19.9 // indirect
|
||||
github.com/go-openapi/strfmt v0.19.5
|
||||
github.com/go-openapi/swag v0.19.9
|
||||
github.com/go-openapi/validate v0.19.10
|
||||
github.com/go-redis/redis/v7 v7.4.0 // indirect
|
||||
github.com/go-swagger/go-swagger v0.25.0 // indirect
|
||||
github.com/gorilla/handlers v1.5.0 // indirect
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
|
||||
github.com/justinas/alice v1.2.0
|
||||
github.com/kr/pretty v0.2.1 // indirect
|
||||
github.com/magiconair/properties v1.8.2 // indirect
|
||||
github.com/mailru/easyjson v0.7.6 // indirect
|
||||
github.com/mitchellh/mapstructure v1.3.3 // indirect
|
||||
github.com/oauth2-proxy/oauth2-proxy v1.1.2-0.20200817154438-5fa5b3186f39
|
||||
github.com/pelletier/go-toml v1.8.0 // indirect
|
||||
github.com/pquerna/cachecontrol v0.0.0-20200819021114-67c6ae64274f // indirect
|
||||
github.com/recws-org/recws v1.2.1
|
||||
github.com/sirupsen/logrus v1.4.2
|
||||
github.com/spf13/afero v1.3.4 // indirect
|
||||
github.com/spf13/cast v1.3.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/viper v1.7.1 // indirect
|
||||
github.com/stretchr/testify v1.6.1
|
||||
go.mongodb.org/mongo-driver v1.4.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de // indirect
|
||||
golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect
|
||||
golang.org/x/sys v0.0.0-20200828194041-157a740278f4 // indirect
|
||||
golang.org/x/tools v0.0.0-20200828161849-5deb26317202 // indirect
|
||||
gopkg.in/ini.v1 v1.60.2 // indirect
|
||||
gopkg.in/square/go-jose.v2 v2.5.1 // indirect
|
||||
)
|
||||
1042
proxy/go.sum
Normal file
1042
proxy/go.sum
Normal file
File diff suppressed because it is too large
Load Diff
11
proxy/main.go
Normal file
11
proxy/main.go
Normal file
@ -0,0 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/BeryJu/passbook/proxy/cmd"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
cmd.RunServer()
|
||||
}
|
||||
1039
proxy/pkg/proxy/oauthproxy.go
Normal file
1039
proxy/pkg/proxy/oauthproxy.go
Normal file
File diff suppressed because it is too large
Load Diff
187
proxy/pkg/proxy/templates.go
Normal file
187
proxy/pkg/proxy/templates.go
Normal file
@ -0,0 +1,187 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
)
|
||||
|
||||
func loadTemplates(dir string) *template.Template {
|
||||
if dir == "" {
|
||||
return getTemplates()
|
||||
}
|
||||
logger.Printf("using custom template directory %q", dir)
|
||||
funcMap := template.FuncMap{
|
||||
"ToUpper": strings.ToUpper,
|
||||
"ToLower": strings.ToLower,
|
||||
}
|
||||
t, err := template.New("").Funcs(funcMap).ParseFiles(path.Join(dir, "sign_in.html"), path.Join(dir, "error.html"))
|
||||
if err != nil {
|
||||
logger.Fatalf("failed parsing template %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func getTemplates() *template.Template {
|
||||
t, err := template.New("foo").Parse(`{{define "sign_in.html"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" charset="utf-8">
|
||||
<head>
|
||||
<title>Sign In</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
|
||||
<style>
|
||||
body {
|
||||
font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
|
||||
font-size: 14px;
|
||||
line-height: 1.42857143;
|
||||
color: #333;
|
||||
background: #f0f0f0;
|
||||
}
|
||||
.signin {
|
||||
display:block;
|
||||
margin:20px auto;
|
||||
max-width:400px;
|
||||
background: #fff;
|
||||
border:1px solid #ccc;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
}
|
||||
.center {
|
||||
text-align:center;
|
||||
}
|
||||
.btn {
|
||||
color: #fff;
|
||||
background-color: #428bca;
|
||||
border: 1px solid #357ebd;
|
||||
-webkit-border-radius: 4;
|
||||
-moz-border-radius: 4;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
padding: 6px 12px;
|
||||
text-decoration: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.btn:hover {
|
||||
background-color: #3071a9;
|
||||
border-color: #285e8e;
|
||||
text-decoration: none;
|
||||
}
|
||||
label {
|
||||
display: inline-block;
|
||||
max-width: 100%;
|
||||
margin-bottom: 5px;
|
||||
font-weight: 700;
|
||||
}
|
||||
input {
|
||||
display: block;
|
||||
width: 100%;
|
||||
height: 34px;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
line-height: 1.42857143;
|
||||
color: #555;
|
||||
background-color: #fff;
|
||||
background-image: none;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
-webkit-box-shadow: inset 0 1px 1px rgba(0,0,0,.075);
|
||||
box-shadow: inset 0 1px 1px rgba(0,0,0,.075);
|
||||
-webkit-transition: border-color ease-in-out .15s,-webkit-box-shadow ease-in-out .15s;
|
||||
-o-transition: border-color ease-in-out .15s,box-shadow ease-in-out .15s;
|
||||
transition: border-color ease-in-out .15s,box-shadow ease-in-out .15s;
|
||||
margin:0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
footer {
|
||||
display:block;
|
||||
font-size:10px;
|
||||
color:#aaa;
|
||||
text-align:center;
|
||||
margin-bottom:10px;
|
||||
}
|
||||
footer a {
|
||||
display:inline-block;
|
||||
height:25px;
|
||||
line-height:25px;
|
||||
color:#aaa;
|
||||
text-decoration:underline;
|
||||
}
|
||||
footer a:hover {
|
||||
color:#aaa;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="signin center">
|
||||
<form method="GET" action="{{.ProxyPrefix}}/start">
|
||||
<input type="hidden" name="rd" value="{{.Redirect}}">
|
||||
{{ if .SignInMessage }}
|
||||
<p>{{.SignInMessage}}</p>
|
||||
{{ end}}
|
||||
<button type="submit" class="btn">Sign in with {{.ProviderName}}</button><br/>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
{{ if .CustomLogin }}
|
||||
<div class="signin">
|
||||
<form method="POST" action="{{.ProxyPrefix}}/sign_in">
|
||||
<input type="hidden" name="rd" value="{{.Redirect}}">
|
||||
<label for="username">Username:</label><input type="text" name="username" id="username" size="10"><br/>
|
||||
<label for="password">Password:</label><input type="password" name="password" id="password" size="10"><br/>
|
||||
<button type="submit" class="btn">Sign In</button>
|
||||
</form>
|
||||
</div>
|
||||
{{ end }}
|
||||
<script>
|
||||
if (window.location.hash) {
|
||||
(function() {
|
||||
var inputs = document.getElementsByName('rd');
|
||||
for (var i = 0; i < inputs.length; i++) {
|
||||
// Add hash, but make sure it is only added once
|
||||
var idx = inputs[i].value.indexOf('#');
|
||||
if (idx >= 0) {
|
||||
// Remove existing hash from URL
|
||||
inputs[i].value = inputs[i].value.substr(0, idx);
|
||||
}
|
||||
inputs[i].value += window.location.hash;
|
||||
}
|
||||
})();
|
||||
}
|
||||
</script>
|
||||
<footer>
|
||||
{{ if eq .Footer "-" }}
|
||||
{{ else if eq .Footer ""}}
|
||||
Secured with <a href="https://github.com/oauth2-proxy/oauth2-proxy#oauth2_proxy">OAuth2 Proxy</a> version {{.Version}}
|
||||
{{ else }}
|
||||
{{.Footer}}
|
||||
{{ end }}
|
||||
</footer>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}`)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed parsing template %s", err)
|
||||
}
|
||||
|
||||
t, err = t.Parse(`{{define "error.html"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en" charset="utf-8">
|
||||
<head>
|
||||
<title>{{.Title}}</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
|
||||
</head>
|
||||
<body>
|
||||
<h2>{{.Title}}</h2>
|
||||
<p>{{.Message}}</p>
|
||||
<hr>
|
||||
<p><a href="{{.ProxyPrefix}}/sign_in">Sign In</a></p>
|
||||
</body>
|
||||
</html>{{end}}`)
|
||||
if err != nil {
|
||||
logger.Fatalf("failed parsing template %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
62
proxy/pkg/proxy/templates_test.go
Normal file
62
proxy/pkg/proxy/templates_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLoadTemplates(t *testing.T) {
|
||||
data := struct {
|
||||
TestString string
|
||||
}{
|
||||
TestString: "Testing",
|
||||
}
|
||||
|
||||
templates := loadTemplates("")
|
||||
assert.NotEqual(t, templates, nil)
|
||||
|
||||
var defaultSignin bytes.Buffer
|
||||
templates.ExecuteTemplate(&defaultSignin, "sign_in.html", data)
|
||||
assert.Equal(t, "\n<!DOCTYPE html>", defaultSignin.String()[0:16])
|
||||
|
||||
var defaultError bytes.Buffer
|
||||
templates.ExecuteTemplate(&defaultError, "error.html", data)
|
||||
assert.Equal(t, "\n<!DOCTYPE html>", defaultError.String()[0:16])
|
||||
|
||||
dir, err := ioutil.TempDir("", "templatetest")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
templateHTML := `{{.TestString}} {{.TestString | ToLower}} {{.TestString | ToUpper}}`
|
||||
signInFile := filepath.Join(dir, "sign_in.html")
|
||||
if err := ioutil.WriteFile(signInFile, []byte(templateHTML), 0666); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
errorFile := filepath.Join(dir, "error.html")
|
||||
if err := ioutil.WriteFile(errorFile, []byte(templateHTML), 0666); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
templates = loadTemplates(dir)
|
||||
assert.NotEqual(t, templates, nil)
|
||||
|
||||
var sitpl bytes.Buffer
|
||||
templates.ExecuteTemplate(&sitpl, "sign_in.html", data)
|
||||
assert.Equal(t, "Testing testing TESTING", sitpl.String())
|
||||
|
||||
var errtpl bytes.Buffer
|
||||
templates.ExecuteTemplate(&errtpl, "error.html", data)
|
||||
assert.Equal(t, "Testing testing TESTING", errtpl.String())
|
||||
}
|
||||
|
||||
func TestTemplatesCompile(t *testing.T) {
|
||||
templates := getTemplates()
|
||||
assert.NotEqual(t, templates, nil)
|
||||
}
|
||||
212
proxy/pkg/server/api.go
Normal file
212
proxy/pkg/server/api.go
Normal file
@ -0,0 +1,212 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client/outposts"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/go-openapi/runtime"
|
||||
"github.com/recws-org/recws"
|
||||
|
||||
httptransport "github.com/go-openapi/runtime/client"
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const ConfigLogLevel = "log_level"
|
||||
const ConfigErrorReportingEnabled = "error_reporting_enabled"
|
||||
const ConfigErrorReportingEnvironment = "error_reporting_environment"
|
||||
|
||||
// APIController main controller which connects to the passbook api via http and ws
|
||||
type APIController struct {
|
||||
client *client.Passbook
|
||||
auth runtime.ClientAuthInfoWriter
|
||||
token string
|
||||
|
||||
server *Server
|
||||
|
||||
commonOpts *options.Options
|
||||
|
||||
lastBundleHash string
|
||||
logger *log.Entry
|
||||
|
||||
wsConn recws.RecConn
|
||||
}
|
||||
|
||||
func getCommonOptions() *options.Options {
|
||||
commonOpts := options.NewOptions()
|
||||
commonOpts.Cookie.Name = "passbook_proxy"
|
||||
commonOpts.EmailDomains = []string{"*"}
|
||||
commonOpts.ProviderType = "oidc"
|
||||
commonOpts.ProxyPrefix = "/pbprox"
|
||||
commonOpts.SkipProviderButton = true
|
||||
commonOpts.Logging.SilencePing = true
|
||||
commonOpts.SetXAuthRequest = true
|
||||
commonOpts.SetAuthorization = true
|
||||
return commonOpts
|
||||
}
|
||||
|
||||
func doGlobalSetup(config map[string]interface{}) {
|
||||
switch config[ConfigLogLevel].(string) {
|
||||
case "debug":
|
||||
log.SetLevel(log.DebugLevel)
|
||||
case "info":
|
||||
log.SetLevel(log.InfoLevel)
|
||||
case "warning":
|
||||
log.SetLevel(log.WarnLevel)
|
||||
case "error":
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
default:
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
var dsn string
|
||||
if config[ConfigErrorReportingEnabled].(bool) {
|
||||
dsn = "https://33cdbcb23f8b436dbe0ee06847410b67@sentry.beryju.org/3"
|
||||
log.Debug("Error reporting enabled")
|
||||
}
|
||||
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
Dsn: dsn,
|
||||
Environment: config[ConfigErrorReportingEnvironment].(string),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("sentry.Init: %s", err)
|
||||
}
|
||||
|
||||
defer sentry.Flush(2 * time.Second)
|
||||
}
|
||||
|
||||
func getTLSTransport() http.RoundTripper {
|
||||
_, set := os.LookupEnv("PASSBOOK_INSECURE")
|
||||
tlsTransport, err := httptransport.TLSTransport(httptransport.TLSClientOptions{
|
||||
InsecureSkipVerify: set,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tlsTransport
|
||||
}
|
||||
|
||||
// NewAPIController initialise new API Controller instance from URL and API token
|
||||
func NewAPIController(pbURL url.URL, token string) *APIController {
|
||||
transport := httptransport.New(pbURL.Host, client.DefaultBasePath, []string{pbURL.Scheme})
|
||||
|
||||
transport.Transport = getTLSTransport()
|
||||
|
||||
// create the transport
|
||||
auth := httptransport.BasicAuth("", token)
|
||||
|
||||
// create the API client, with the transport
|
||||
apiClient := client.New(transport, strfmt.Default)
|
||||
|
||||
// Because we don't know the outpost UUID, we simply do a list and pick the first
|
||||
// The service account this token belongs to should only have access to a single outpost
|
||||
outposts, err := apiClient.Outposts.OutpostsOutpostsList(outposts.NewOutpostsOutpostsListParams(), auth)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
outpost := outposts.Payload.Results[0]
|
||||
doGlobalSetup(outpost.Config.(map[string]interface{}))
|
||||
|
||||
ac := &APIController{
|
||||
client: apiClient,
|
||||
auth: auth,
|
||||
token: token,
|
||||
|
||||
logger: log.WithField("component", "api-controller"),
|
||||
commonOpts: getCommonOptions(),
|
||||
server: NewServer(),
|
||||
|
||||
lastBundleHash: "",
|
||||
}
|
||||
ac.initWS(pbURL, outpost.Pk)
|
||||
return ac
|
||||
}
|
||||
|
||||
func (a *APIController) bundleProviders() ([]*providerBundle, error) {
|
||||
providers, err := a.client.Outposts.OutpostsProxyList(outposts.NewOutpostsProxyListParams(), a.auth)
|
||||
if err != nil {
|
||||
a.logger.WithError(err).Error("Failed to fetch providers")
|
||||
return nil, err
|
||||
}
|
||||
// Check provider hash to see if anything is changed
|
||||
hasher := sha512.New()
|
||||
bin, _ := providers.Payload.MarshalBinary()
|
||||
hash := hex.EncodeToString(hasher.Sum(bin))
|
||||
if hash == a.lastBundleHash {
|
||||
return nil, nil
|
||||
}
|
||||
a.lastBundleHash = hash
|
||||
|
||||
bundles := make([]*providerBundle, len(providers.Payload.Results))
|
||||
|
||||
for idx, provider := range providers.Payload.Results {
|
||||
externalHost, err := url.Parse(*provider.ExternalHost)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("Failed to parse URL, skipping provider")
|
||||
}
|
||||
bundles[idx] = &providerBundle{
|
||||
a: a,
|
||||
Host: externalHost.Hostname(),
|
||||
}
|
||||
bundles[idx].Build(provider)
|
||||
}
|
||||
return bundles, nil
|
||||
}
|
||||
|
||||
func (a *APIController) updateHTTPServer(bundles []*providerBundle) {
|
||||
newMap := make(map[string]*providerBundle)
|
||||
for _, bundle := range bundles {
|
||||
newMap[bundle.Host] = bundle
|
||||
}
|
||||
a.logger.Debug("Swapped maps")
|
||||
a.server.Handlers = newMap
|
||||
}
|
||||
|
||||
// UpdateIfRequired Updates the HTTP Server config if required, automatically swaps the handlers
|
||||
func (a *APIController) UpdateIfRequired() error {
|
||||
bundles, err := a.bundleProviders()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if bundles == nil {
|
||||
a.logger.Debug("Providers have not changed, not updating")
|
||||
return nil
|
||||
}
|
||||
a.updateHTTPServer(bundles)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start Starts all handlers, non-blocking
|
||||
func (a *APIController) Start() error {
|
||||
err := a.UpdateIfRequired()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
a.logger.Debug("Starting HTTP Server...")
|
||||
a.server.ServeHTTP()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting HTTPs Server...")
|
||||
a.server.ServeHTTPS()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS Handler...")
|
||||
a.startWSHandler()
|
||||
}()
|
||||
go func() {
|
||||
a.logger.Debug("Starting WS Health notifier...")
|
||||
a.startWSHealth()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
123
proxy/pkg/server/api_bundle.go
Normal file
123
proxy/pkg/server/api_bundle.go
Normal file
@ -0,0 +1,123 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"github.com/BeryJu/passbook/proxy/pkg/client/crypto"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/models"
|
||||
"github.com/BeryJu/passbook/proxy/pkg/proxy"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/justinas/alice"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware"
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/validation"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type providerBundle struct {
|
||||
http.Handler
|
||||
|
||||
a *APIController
|
||||
proxy *proxy.OAuthProxy
|
||||
Host string
|
||||
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
func (pb *providerBundle) prepareOpts(provider *models.ProxyOutpostConfig) *options.Options {
|
||||
externalHost, err := url.Parse(*provider.ExternalHost)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("Failed to parse URL, skipping provider")
|
||||
return nil
|
||||
}
|
||||
providerOpts := &options.Options{}
|
||||
copier.Copy(&providerOpts, &pb.a.commonOpts)
|
||||
providerOpts.ClientID = provider.ClientID
|
||||
providerOpts.ClientSecret = provider.ClientSecret
|
||||
|
||||
providerOpts.Cookie.Secret = provider.CookieSecret
|
||||
providerOpts.Cookie.Secure = externalHost.Scheme == "https"
|
||||
|
||||
providerOpts.SkipOIDCDiscovery = true
|
||||
providerOpts.OIDCIssuerURL = *provider.OidcConfiguration.Issuer
|
||||
providerOpts.LoginURL = *provider.OidcConfiguration.AuthorizationEndpoint
|
||||
providerOpts.RedeemURL = *provider.OidcConfiguration.TokenEndpoint
|
||||
providerOpts.OIDCJwksURL = *provider.OidcConfiguration.JwksURI
|
||||
providerOpts.ProfileURL = *provider.OidcConfiguration.UserinfoEndpoint
|
||||
|
||||
providerOpts.UpstreamServers = []options.Upstream{
|
||||
{
|
||||
ID: "default",
|
||||
URI: *provider.InternalHost,
|
||||
Path: "/",
|
||||
},
|
||||
}
|
||||
|
||||
if provider.Certificate != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).Debug("Enabling TLS")
|
||||
cert, err := pb.a.client.Crypto.CryptoCertificatekeypairsRead(&crypto.CryptoCertificatekeypairsReadParams{
|
||||
Context: context.Background(),
|
||||
KpUUID: *provider.Certificate,
|
||||
}, pb.a.auth)
|
||||
if err != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to fetch certificate")
|
||||
return providerOpts
|
||||
}
|
||||
x509cert, err := tls.X509KeyPair([]byte(*cert.Payload.CertificateData), []byte(cert.Payload.KeyData))
|
||||
if err != nil {
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithError(err).Warning("Failed to parse certificate")
|
||||
return providerOpts
|
||||
}
|
||||
pb.cert = &x509cert
|
||||
pb.a.logger.WithField("provider", provider.ClientID).WithField("certificate-key-pair", *cert.Payload.Name).Debug("Loaded certificates")
|
||||
}
|
||||
return providerOpts
|
||||
}
|
||||
|
||||
func (pb *providerBundle) Build(provider *models.ProxyOutpostConfig) {
|
||||
opts := pb.prepareOpts(provider)
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
if opts.ForceHTTPS {
|
||||
_, httpsPort, err := net.SplitHostPort(opts.HTTPSAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("FATAL: invalid HTTPS address %q: %v", opts.HTTPAddress, err)
|
||||
}
|
||||
chain = chain.Append(middleware.NewRedirectToHTTPS(httpsPort))
|
||||
}
|
||||
|
||||
healthCheckPaths := []string{opts.PingPath}
|
||||
healthCheckUserAgents := []string{opts.PingUserAgent}
|
||||
if opts.GCPHealthChecks {
|
||||
healthCheckPaths = append(healthCheckPaths, "/liveness_check", "/readiness_check")
|
||||
healthCheckUserAgents = append(healthCheckUserAgents, "GoogleHC/1.0")
|
||||
}
|
||||
|
||||
// To silence logging of health checks, register the health check handler before
|
||||
// the logging handler
|
||||
if opts.Logging.SilencePing {
|
||||
chain = chain.Append(middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), LoggingHandler)
|
||||
} else {
|
||||
chain = chain.Append(LoggingHandler, middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents))
|
||||
}
|
||||
|
||||
err := validation.Validate(opts)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
oauthproxy, err := proxy.NewOAuthProxy(opts)
|
||||
if err != nil {
|
||||
log.Errorf("ERROR: Failed to initialise OAuth2 Proxy: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
pb.proxy = oauthproxy
|
||||
pb.Handler = chain.Then(oauthproxy)
|
||||
}
|
||||
85
proxy/pkg/server/api_ws.go
Normal file
85
proxy/pkg/server/api_ws.go
Normal file
@ -0,0 +1,85 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-openapi/strfmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/recws-org/recws"
|
||||
)
|
||||
|
||||
func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
|
||||
pathTemplate := "%s://%s/ws/outpost/%s/"
|
||||
scheme := strings.ReplaceAll(pbURL.Scheme, "http", "ws")
|
||||
|
||||
header := http.Header{
|
||||
"Authorization": []string{ac.token},
|
||||
}
|
||||
|
||||
_, set := os.LookupEnv("PASSBOOK_INSECURE")
|
||||
|
||||
ws := recws.RecConn{
|
||||
// KeepAliveTimeout: 10 * time.Second,
|
||||
NonVerbose: true,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: set,
|
||||
},
|
||||
}
|
||||
ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header)
|
||||
|
||||
ac.logger.WithField("outpost", outpostUUID.String()).Debug("connecting to passbook")
|
||||
|
||||
ac.wsConn = ws
|
||||
}
|
||||
|
||||
// Shutdown Gracefully stops all workers, disconnects from websocket
|
||||
func (ac *APIController) Shutdown() {
|
||||
// Cleanly close the connection by sending a close message and then
|
||||
// waiting (with timeout) for the server to close the connection.
|
||||
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
ac.logger.Println("write close:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ac *APIController) startWSHandler() {
|
||||
for {
|
||||
var wsMsg websocketMessage
|
||||
err := ac.wsConn.ReadJSON(&wsMsg)
|
||||
if err != nil {
|
||||
ac.logger.Println("read:", err)
|
||||
return
|
||||
}
|
||||
if wsMsg.Instruction != WebsocketInstructionAck {
|
||||
ac.logger.Debugf("%+v\n", wsMsg)
|
||||
}
|
||||
if wsMsg.Instruction == WebsocketInstructionTriggerUpdate {
|
||||
err := ac.UpdateIfRequired()
|
||||
if err != nil {
|
||||
ac.logger.WithError(err).Debug("Failed to update")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ac *APIController) startWSHealth() {
|
||||
for ; true; <-time.Tick(time.Second * 10) {
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: make(map[string]interface{}),
|
||||
}
|
||||
err := ac.wsConn.WriteJSON(aliveMsg)
|
||||
if err != nil {
|
||||
ac.logger.Println("write:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
17
proxy/pkg/server/api_ws_msg.go
Normal file
17
proxy/pkg/server/api_ws_msg.go
Normal file
@ -0,0 +1,17 @@
|
||||
package server
|
||||
|
||||
type websocketInstruction int
|
||||
|
||||
const (
|
||||
// WebsocketInstructionAck Code used to acknowledge a previous message
|
||||
WebsocketInstructionAck websocketInstruction = 0
|
||||
// WebsocketInstructionHello Code used to send a healthcheck keepalive
|
||||
WebsocketInstructionHello websocketInstruction = 1
|
||||
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
||||
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
||||
)
|
||||
|
||||
type websocketMessage struct {
|
||||
Instruction websocketInstruction `json:"instruction"`
|
||||
Args map[string]interface{} `json:"args"`
|
||||
}
|
||||
63
proxy/pkg/server/cert.go
Normal file
63
proxy/pkg/server/cert.go
Normal file
@ -0,0 +1,63 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert() (tls.Certificate, error) {
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate private key: %v", err)
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment
|
||||
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(365 * 24 * time.Hour)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate serial number: %v", err)
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"passbook"},
|
||||
CommonName: "passbook Proxy default certificate",
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: keyUsage,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
template.DNSNames = []string{"*"}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
privPemByes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
|
||||
return tls.X509KeyPair(pemBytes, privPemByes)
|
||||
}
|
||||
123
proxy/pkg/server/middleware.go
Normal file
123
proxy/pkg/server/middleware.go
Normal file
@ -0,0 +1,123 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP status
|
||||
// code and body size
|
||||
type responseLogger struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
size int
|
||||
upstream string
|
||||
authInfo string
|
||||
}
|
||||
|
||||
// Header returns the ResponseWriter's Header
|
||||
func (l *responseLogger) Header() http.Header {
|
||||
return l.w.Header()
|
||||
}
|
||||
|
||||
// Support Websocket
|
||||
func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
if hj, ok := l.w.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, errors.New("http.Hijacker is not available on writer")
|
||||
}
|
||||
|
||||
// ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's
|
||||
// Header
|
||||
func (l *responseLogger) ExtractGAPMetadata() {
|
||||
upstream := l.w.Header().Get("GAP-Upstream-Address")
|
||||
if upstream != "" {
|
||||
l.upstream = upstream
|
||||
l.w.Header().Del("GAP-Upstream-Address")
|
||||
}
|
||||
authInfo := l.w.Header().Get("GAP-Auth")
|
||||
if authInfo != "" {
|
||||
l.authInfo = authInfo
|
||||
l.w.Header().Del("GAP-Auth")
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes the response using the ResponseWriter
|
||||
func (l *responseLogger) Write(b []byte) (int, error) {
|
||||
if l.status == 0 {
|
||||
// The status will be StatusOK if WriteHeader has not been called yet
|
||||
l.status = http.StatusOK
|
||||
}
|
||||
l.ExtractGAPMetadata()
|
||||
size, err := l.w.Write(b)
|
||||
l.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
// WriteHeader writes the status code for the Response
|
||||
func (l *responseLogger) WriteHeader(s int) {
|
||||
l.ExtractGAPMetadata()
|
||||
l.w.WriteHeader(s)
|
||||
l.status = s
|
||||
}
|
||||
|
||||
// Status returns the response status code
|
||||
func (l *responseLogger) Status() int {
|
||||
return l.status
|
||||
}
|
||||
|
||||
// Size returns the response size
|
||||
func (l *responseLogger) Size() int {
|
||||
return l.size
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client
|
||||
func (l *responseLogger) Flush() {
|
||||
if flusher, ok := l.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// loggingHandler is the http.Handler implementation for LoggingHandler
|
||||
type loggingHandler struct {
|
||||
handler http.Handler
|
||||
logger *log.Entry
|
||||
}
|
||||
|
||||
// LoggingHandler provides an http.Handler which logs requests to the HTTP server
|
||||
func LoggingHandler(h http.Handler) http.Handler {
|
||||
return loggingHandler{
|
||||
handler: h,
|
||||
logger: log.WithField("component", "http-server"),
|
||||
}
|
||||
}
|
||||
|
||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
t := time.Now()
|
||||
url := *req.URL
|
||||
responseLogger := &responseLogger{w: w}
|
||||
h.handler.ServeHTTP(responseLogger, req)
|
||||
duration := float64(time.Since(t)) / float64(time.Second)
|
||||
h.logger.WithFields(log.Fields{
|
||||
"Client": req.RemoteAddr,
|
||||
"Host": req.Host,
|
||||
"Protocol": req.Proto,
|
||||
"RequestDuration": fmt.Sprintf("%0.3f", duration),
|
||||
"RequestMethod": req.Method,
|
||||
"ResponseSize": responseLogger.Size(),
|
||||
"StatusCode": responseLogger.Status(),
|
||||
"Timestamp": logger.FormatTimestamp(t),
|
||||
"Upstream": responseLogger.upstream,
|
||||
"UserAgent": req.UserAgent(),
|
||||
"Username": responseLogger.authInfo,
|
||||
}).Info(url.RequestURI())
|
||||
// logger.PrintReq(responseLogger.authInfo, responseLogger.upstream, req, url, t, , )
|
||||
}
|
||||
152
proxy/pkg/server/server.go
Normal file
152
proxy/pkg/server/server.go
Normal file
@ -0,0 +1,152 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
sentryhttp "github.com/getsentry/sentry-go/http"
|
||||
)
|
||||
|
||||
// Server represents an HTTP server
|
||||
type Server struct {
|
||||
Handlers map[string]*providerBundle
|
||||
|
||||
stop chan struct{} // channel for waiting shutdown
|
||||
logger *log.Entry
|
||||
|
||||
defaultCert tls.Certificate
|
||||
}
|
||||
|
||||
// NewServer initialise a new HTTP Server
|
||||
func NewServer() *Server {
|
||||
defaultCert, err := generateSelfSignedCert()
|
||||
if err != nil {
|
||||
log.Warning(err)
|
||||
}
|
||||
return &Server{
|
||||
Handlers: make(map[string]*providerBundle),
|
||||
logger: log.WithField("component", "http-server"),
|
||||
defaultCert: defaultCert,
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP constructs a net.Listener and starts handling HTTP requests
|
||||
func (s *Server) ServeHTTP() {
|
||||
// TODO: make this a setting
|
||||
listenAddress := "localhost:4180"
|
||||
listener, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
|
||||
}
|
||||
s.logger.Printf("listening on %s", listener.Addr())
|
||||
s.serve(listener)
|
||||
s.logger.Printf("closing %s", listener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
handler, ok := s.Handlers[info.ServerName]
|
||||
if !ok {
|
||||
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not exist")
|
||||
return &s.defaultCert, nil
|
||||
}
|
||||
if handler.cert == nil {
|
||||
s.logger.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
|
||||
return &s.defaultCert, nil
|
||||
}
|
||||
return handler.cert, nil
|
||||
}
|
||||
|
||||
// ServeHTTPS constructs a net.Listener and starts handling HTTPS requests
|
||||
func (s *Server) ServeHTTPS() {
|
||||
// TODO: make this a setting
|
||||
listenAddress := "localhost:4443"
|
||||
config := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS12,
|
||||
GetCertificate: s.getCertificates,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", listenAddress)
|
||||
if err != nil {
|
||||
s.logger.Fatalf("FATAL: listen (%s) failed - %s", listenAddress, err)
|
||||
}
|
||||
s.logger.Printf("listening on %s", ln.Addr())
|
||||
|
||||
tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
|
||||
s.serve(tlsListener)
|
||||
s.logger.Printf("closing %s", tlsListener.Addr())
|
||||
}
|
||||
|
||||
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
|
||||
handler, ok := s.Handlers[r.Host]
|
||||
if !ok {
|
||||
// If we only have one handler, host name switching doesn't matter
|
||||
if len(s.Handlers) == 1 {
|
||||
for k := range s.Handlers {
|
||||
s.Handlers[k].ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.logger.WithField("host", r.Host).Debug("Host header does not match any we know of")
|
||||
s.logger.Printf("%v+\n", s.Handlers)
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
s.logger.WithField("host", r.Host).Debug("passing request from host head")
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) serve(listener net.Listener) {
|
||||
sentryHandler := sentryhttp.New(sentryhttp.Options{})
|
||||
|
||||
srv := &http.Server{Handler: sentryHandler.HandleFunc(s.handler)}
|
||||
|
||||
// See https://golang.org/pkg/net/http/#Server.Shutdown
|
||||
idleConnsClosed := make(chan struct{})
|
||||
go func() {
|
||||
<-s.stop // wait notification for stopping server
|
||||
|
||||
// We received an interrupt signal, shut down.
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
// Error from closing listeners, or context timeout:
|
||||
s.logger.Printf("HTTP server Shutdown: %v", err)
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
|
||||
err := srv.Serve(listener)
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.logger.Errorf("ERROR: http.Serve() - %s", err)
|
||||
}
|
||||
<-idleConnsClosed
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = tc.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
log.Printf("Error setting Keep-Alive: %v", err)
|
||||
}
|
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
if err != nil {
|
||||
log.Printf("Error setting Keep-Alive period: %v", err)
|
||||
}
|
||||
return tc, nil
|
||||
}
|
||||
Reference in New Issue
Block a user