providers/proxy: improve SLO by backchannel logging out sessions (#7099)
* outposts: add support for provider-specific websocket messages Signed-off-by: Jens Langhammer <jens@goauthentik.io> * providers/proxy: add custom signal on logout to logout in provider Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -22,6 +22,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type WSHandler func(ctx context.Context, args map[string]interface{})
|
||||
|
||||
const ConfigLogLevel = "log_level"
|
||||
|
||||
// APIController main controller which connects to the authentik api via http and ws
|
||||
@ -42,6 +44,7 @@ type APIController struct {
|
||||
lastWsReconnect time.Time
|
||||
wsIsReconnecting bool
|
||||
wsBackoffMultiplier int
|
||||
wsHandlers []WSHandler
|
||||
refreshHandlers []func()
|
||||
|
||||
instanceUUID uuid.UUID
|
||||
@ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController {
|
||||
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
|
||||
instanceUUID: uuid.New(),
|
||||
Outpost: outpost,
|
||||
wsHandlers: []WSHandler{},
|
||||
wsBackoffMultiplier: 1,
|
||||
refreshHandlers: make([]func(), 0),
|
||||
}
|
||||
@ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) {
|
||||
a.refreshHandlers = append(a.refreshHandlers, handler)
|
||||
}
|
||||
|
||||
func (a *APIController) AddWSHandler(handler WSHandler) {
|
||||
a.wsHandlers = append(a.wsHandlers, handler)
|
||||
}
|
||||
|
||||
func (a *APIController) OnRefresh() error {
|
||||
// Because we don't know the outpost UUID, we simply do a list and pick the first
|
||||
// The service account this token belongs to should only have access to a single outpost
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package ak
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() {
|
||||
"build": constants.BUILD("tagged"),
|
||||
}).SetToCurrentTime()
|
||||
}
|
||||
} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific {
|
||||
for _, h := range ac.wsHandlers {
|
||||
h(context.Background(), wsMsg.Args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,6 +9,8 @@ const (
|
||||
WebsocketInstructionHello websocketInstruction = 1
|
||||
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
|
||||
WebsocketInstructionTriggerUpdate websocketInstruction = 2
|
||||
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
|
||||
WebsocketInstructionProviderSpecific websocketInstruction = 3
|
||||
)
|
||||
|
||||
type websocketMessage struct {
|
||||
|
||||
@ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) {
|
||||
"id_token_hint": []string{cc.RawToken},
|
||||
}
|
||||
redirect += "?" + uv.Encode()
|
||||
err = a.Logout(r.Context(), cc.Sub)
|
||||
err = a.Logout(r.Context(), func(c Claims) bool {
|
||||
return c.Sub == cc.Sub
|
||||
})
|
||||
if err != nil {
|
||||
a.log.WithError(err).Warning("failed to logout of other sessions")
|
||||
}
|
||||
|
||||
@ -11,10 +11,11 @@ type Claims struct {
|
||||
Exp int `json:"exp"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"email_verified"`
|
||||
Proxy *ProxyClaims `json:"ak_proxy"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Groups []string `json:"groups"`
|
||||
Sid string `json:"sid"`
|
||||
Proxy *ProxyClaims `json:"ak_proxy"`
|
||||
|
||||
RawToken string
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec {
|
||||
return cs
|
||||
}
|
||||
|
||||
func (a *Application) Logout(ctx context.Context, sub string) error {
|
||||
func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error {
|
||||
if _, ok := a.sessions.(*sessions.FilesystemStore); ok {
|
||||
files, err := os.ReadDir(os.TempDir())
|
||||
if err != nil {
|
||||
@ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
|
||||
continue
|
||||
}
|
||||
claims := s.Values[constants.SessionClaims].(Claims)
|
||||
if claims.Sub == sub {
|
||||
if filter(claims) {
|
||||
a.log.WithField("path", fullPath).Trace("deleting session")
|
||||
err := os.Remove(fullPath)
|
||||
if err != nil {
|
||||
@ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error {
|
||||
continue
|
||||
}
|
||||
claims := c.(Claims)
|
||||
if claims.Sub == sub {
|
||||
if filter(claims) {
|
||||
a.log.WithField("key", key).Trace("deleting session")
|
||||
_, err := client.Del(ctx, key).Result()
|
||||
if err != nil {
|
||||
|
||||
@ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer {
|
||||
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
|
||||
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
|
||||
rootMux.PathPrefix("/").HandlerFunc(s.Handle)
|
||||
ac.AddWSHandler(s.handleWSMessage)
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
49
internal/outpost/proxyv2/ws.go
Normal file
49
internal/outpost/proxyv2/ws.go
Normal file
@ -0,0 +1,49 @@
|
||||
package proxyv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"goauthentik.io/internal/outpost/proxyv2/application"
|
||||
)
|
||||
|
||||
type WSProviderSubType string
|
||||
|
||||
const (
|
||||
WSProviderSubTypeLogout WSProviderSubType = "logout"
|
||||
)
|
||||
|
||||
type WSProviderMsg struct {
|
||||
SubType WSProviderSubType `mapstructure:"sub_type"`
|
||||
SessionID string `mapstructure:"session_id"`
|
||||
}
|
||||
|
||||
func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
|
||||
msg := &WSProviderMsg{}
|
||||
err := mapstructure.Decode(args, &msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
|
||||
msg, err := ParseWSProvider(args)
|
||||
if err != nil {
|
||||
ps.log.WithError(err).Warning("invalid provider-specific ws message")
|
||||
return
|
||||
}
|
||||
switch msg.SubType {
|
||||
case WSProviderSubTypeLogout:
|
||||
for _, p := range ps.apps {
|
||||
err := p.Logout(ctx, func(c application.Claims) bool {
|
||||
return c.Sid == msg.SessionID
|
||||
})
|
||||
if err != nil {
|
||||
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
|
||||
}
|
||||
}
|
||||
default:
|
||||
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user