providers/proxy: improve SLO by backchannel logging out sessions (#7099)
* outposts: add support for provider-specific websocket messages Signed-off-by: Jens Langhammer <jens@goauthentik.io> * providers/proxy: add custom signal on logout to logout in provider Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -27,6 +27,9 @@ class WebsocketMessageInstruction(IntEnum): | |||||||
|     # Message sent by us to trigger an Update |     # Message sent by us to trigger an Update | ||||||
|     TRIGGER_UPDATE = 2 |     TRIGGER_UPDATE = 2 | ||||||
|  |  | ||||||
|  |     # Provider specific message | ||||||
|  |     PROVIDER_SPECIFIC = 3 | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass(slots=True) | @dataclass(slots=True) | ||||||
| class WebsocketMessage: | class WebsocketMessage: | ||||||
| @ -131,3 +134,14 @@ class OutpostConsumer(AuthJsonConsumer): | |||||||
|         self.send_json( |         self.send_json( | ||||||
|             asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) |             asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def event_provider_specific(self, event): | ||||||
|  |         """Event handler which can be called by provider-specific | ||||||
|  |         implementations to send specific messages to the outpost""" | ||||||
|  |         self.send_json( | ||||||
|  |             asdict( | ||||||
|  |                 WebsocketMessage( | ||||||
|  |                     instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | |||||||
| @ -5,7 +5,6 @@ from socket import gethostname | |||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
| from urllib.parse import urlparse | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| import yaml |  | ||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| @ -16,6 +15,7 @@ from docker.constants import DEFAULT_UNIX_SOCKET | |||||||
| from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | ||||||
| from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  | from yaml import safe_load | ||||||
|  |  | ||||||
| from authentik.events.monitored_tasks import ( | from authentik.events.monitored_tasks import ( | ||||||
|     MonitoredTask, |     MonitoredTask, | ||||||
| @ -279,7 +279,7 @@ def outpost_connection_discovery(self: MonitoredTask): | |||||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: |             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||||
|                 KubernetesServiceConnection.objects.create( |                 KubernetesServiceConnection.objects.create( | ||||||
|                     name=kubeconfig_local_name, |                     name=kubeconfig_local_name, | ||||||
|                     kubeconfig=yaml.safe_load(_kubeconfig), |                     kubeconfig=safe_load(_kubeconfig), | ||||||
|                 ) |                 ) | ||||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path |     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||||
|     socket = Path(unix_socket_path) |     socket = Path(unix_socket_path) | ||||||
|  | |||||||
| @ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | |||||||
|     label = "authentik_providers_proxy" |     label = "authentik_providers_proxy" | ||||||
|     verbose_name = "authentik Providers.Proxy" |     verbose_name = "authentik Providers.Proxy" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|  |     def reconcile_load_providers_proxy_signals(self): | ||||||
|  |         """Load proxy signals""" | ||||||
|  |         self.import_module("authentik.providers.proxy.signals") | ||||||
|  | |||||||
							
								
								
									
										20
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | |||||||
|  | """Proxy provider signals""" | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
|  | from django.db.models.signals import pre_delete | ||||||
|  | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
|  | from authentik.core.models import AuthenticatedSession, User | ||||||
|  | from authentik.providers.proxy.tasks import proxy_on_logout | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||||
|  |     """Catch logout by direct logout and forward to proxy providers""" | ||||||
|  |     proxy_on_logout.delay(request.session.session_key) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
|  | def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|  |     """Catch logout by expiring sessions being deleted""" | ||||||
|  |     proxy_on_logout.delay(instance.session_key) | ||||||
| @ -1,6 +1,9 @@ | |||||||
| """proxy provider tasks""" | """proxy provider tasks""" | ||||||
|  | from asgiref.sync import async_to_sync | ||||||
|  | from channels.layers import get_channel_layer | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  |  | ||||||
|  | from authentik.outposts.models import Outpost, OutpostState, OutpostType | ||||||
| from authentik.providers.proxy.models import ProxyProvider | from authentik.providers.proxy.models import ProxyProvider | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| @ -13,3 +16,20 @@ def proxy_set_defaults(): | |||||||
|     for provider in ProxyProvider.objects.all(): |     for provider in ProxyProvider.objects.all(): | ||||||
|         provider.set_oauth_defaults() |         provider.set_oauth_defaults() | ||||||
|         provider.save() |         provider.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task() | ||||||
|  | def proxy_on_logout(session_id: str): | ||||||
|  |     """Update outpost instances connected to a single outpost""" | ||||||
|  |     layer = get_channel_layer() | ||||||
|  |     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||||
|  |         for state in OutpostState.for_outpost(outpost): | ||||||
|  |             for channel in state.channel_ids: | ||||||
|  |                 async_to_sync(layer.send)( | ||||||
|  |                     channel, | ||||||
|  |                     { | ||||||
|  |                         "type": "event.provider.specific", | ||||||
|  |                         "sub_type": "logout", | ||||||
|  |                         "session_id": session_id, | ||||||
|  |                     }, | ||||||
|  |                 ) | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ entries: | |||||||
|         # This mapping is used by the authentik proxy. It passes extra user attributes, |         # This mapping is used by the authentik proxy. It passes extra user attributes, | ||||||
|         # which are used for example for the HTTP-Basic Authentication mapping. |         # which are used for example for the HTTP-Basic Authentication mapping. | ||||||
|         return { |         return { | ||||||
|  |             "sid": request.http_request.session.session_key, | ||||||
|             "ak_proxy": { |             "ak_proxy": { | ||||||
|                 "user_attributes": request.user.group_attributes(request), |                 "user_attributes": request.user.group_attributes(request), | ||||||
|                 "is_superuser": request.user.is_superuser, |                 "is_superuser": request.user.is_superuser, | ||||||
|  | |||||||
| @ -22,6 +22,8 @@ import ( | |||||||
| 	log "github.com/sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type WSHandler func(ctx context.Context, args map[string]interface{}) | ||||||
|  |  | ||||||
| const ConfigLogLevel = "log_level" | const ConfigLogLevel = "log_level" | ||||||
|  |  | ||||||
| // APIController main controller which connects to the authentik api via http and ws | // APIController main controller which connects to the authentik api via http and ws | ||||||
| @ -42,6 +44,7 @@ type APIController struct { | |||||||
| 	lastWsReconnect     time.Time | 	lastWsReconnect     time.Time | ||||||
| 	wsIsReconnecting    bool | 	wsIsReconnecting    bool | ||||||
| 	wsBackoffMultiplier int | 	wsBackoffMultiplier int | ||||||
|  | 	wsHandlers          []WSHandler | ||||||
| 	refreshHandlers     []func() | 	refreshHandlers     []func() | ||||||
|  |  | ||||||
| 	instanceUUID uuid.UUID | 	instanceUUID uuid.UUID | ||||||
| @ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController { | |||||||
| 		reloadOffset:        time.Duration(rand.Intn(10)) * time.Second, | 		reloadOffset:        time.Duration(rand.Intn(10)) * time.Second, | ||||||
| 		instanceUUID:        uuid.New(), | 		instanceUUID:        uuid.New(), | ||||||
| 		Outpost:             outpost, | 		Outpost:             outpost, | ||||||
|  | 		wsHandlers:          []WSHandler{}, | ||||||
| 		wsBackoffMultiplier: 1, | 		wsBackoffMultiplier: 1, | ||||||
| 		refreshHandlers:     make([]func(), 0), | 		refreshHandlers:     make([]func(), 0), | ||||||
| 	} | 	} | ||||||
| @ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) { | |||||||
| 	a.refreshHandlers = append(a.refreshHandlers, handler) | 	a.refreshHandlers = append(a.refreshHandlers, handler) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (a *APIController) AddWSHandler(handler WSHandler) { | ||||||
|  | 	a.wsHandlers = append(a.wsHandlers, handler) | ||||||
|  | } | ||||||
|  |  | ||||||
| func (a *APIController) OnRefresh() error { | func (a *APIController) OnRefresh() error { | ||||||
| 	// Because we don't know the outpost UUID, we simply do a list and pick the first | 	// Because we don't know the outpost UUID, we simply do a list and pick the first | ||||||
| 	// The service account this token belongs to should only have access to a single outpost | 	// The service account this token belongs to should only have access to a single outpost | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package ak | package ak | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() { | |||||||
| 					"build":        constants.BUILD("tagged"), | 					"build":        constants.BUILD("tagged"), | ||||||
| 				}).SetToCurrentTime() | 				}).SetToCurrentTime() | ||||||
| 			} | 			} | ||||||
|  | 		} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific { | ||||||
|  | 			for _, h := range ac.wsHandlers { | ||||||
|  | 				h(context.Background(), wsMsg.Args) | ||||||
|  | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -9,6 +9,8 @@ const ( | |||||||
| 	WebsocketInstructionHello websocketInstruction = 1 | 	WebsocketInstructionHello websocketInstruction = 1 | ||||||
| 	// WebsocketInstructionTriggerUpdate Code received to trigger a config update | 	// WebsocketInstructionTriggerUpdate Code received to trigger a config update | ||||||
| 	WebsocketInstructionTriggerUpdate websocketInstruction = 2 | 	WebsocketInstructionTriggerUpdate websocketInstruction = 2 | ||||||
|  | 	// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function | ||||||
|  | 	WebsocketInstructionProviderSpecific websocketInstruction = 3 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type websocketMessage struct { | type websocketMessage struct { | ||||||
|  | |||||||
| @ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) { | |||||||
| 		"id_token_hint": []string{cc.RawToken}, | 		"id_token_hint": []string{cc.RawToken}, | ||||||
| 	} | 	} | ||||||
| 	redirect += "?" + uv.Encode() | 	redirect += "?" + uv.Encode() | ||||||
| 	err = a.Logout(r.Context(), cc.Sub) | 	err = a.Logout(r.Context(), func(c Claims) bool { | ||||||
|  | 		return c.Sub == cc.Sub | ||||||
|  | 	}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		a.log.WithError(err).Warning("failed to logout of other sessions") | 		a.log.WithError(err).Warning("failed to logout of other sessions") | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -11,10 +11,11 @@ type Claims struct { | |||||||
| 	Exp               int          `json:"exp"` | 	Exp               int          `json:"exp"` | ||||||
| 	Email             string       `json:"email"` | 	Email             string       `json:"email"` | ||||||
| 	Verified          bool         `json:"email_verified"` | 	Verified          bool         `json:"email_verified"` | ||||||
| 	Proxy             *ProxyClaims `json:"ak_proxy"` |  | ||||||
| 	Name              string       `json:"name"` | 	Name              string       `json:"name"` | ||||||
| 	PreferredUsername string       `json:"preferred_username"` | 	PreferredUsername string       `json:"preferred_username"` | ||||||
| 	Groups            []string     `json:"groups"` | 	Groups            []string     `json:"groups"` | ||||||
|  | 	Sid               string       `json:"sid"` | ||||||
|  | 	Proxy             *ProxyClaims `json:"ak_proxy"` | ||||||
|  |  | ||||||
| 	RawToken string | 	RawToken string | ||||||
| } | } | ||||||
|  | |||||||
| @ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec { | |||||||
| 	return cs | 	return cs | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *Application) Logout(ctx context.Context, sub string) error { | func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error { | ||||||
| 	if _, ok := a.sessions.(*sessions.FilesystemStore); ok { | 	if _, ok := a.sessions.(*sessions.FilesystemStore); ok { | ||||||
| 		files, err := os.ReadDir(os.TempDir()) | 		files, err := os.ReadDir(os.TempDir()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			claims := s.Values[constants.SessionClaims].(Claims) | 			claims := s.Values[constants.SessionClaims].(Claims) | ||||||
| 			if claims.Sub == sub { | 			if filter(claims) { | ||||||
| 				a.log.WithField("path", fullPath).Trace("deleting session") | 				a.log.WithField("path", fullPath).Trace("deleting session") | ||||||
| 				err := os.Remove(fullPath) | 				err := os.Remove(fullPath) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| @ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { | |||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			claims := c.(Claims) | 			claims := c.(Claims) | ||||||
| 			if claims.Sub == sub { | 			if filter(claims) { | ||||||
| 				a.log.WithField("key", key).Trace("deleting session") | 				a.log.WithField("key", key).Trace("deleting session") | ||||||
| 				_, err := client.Del(ctx, key).Result() | 				_, err := client.Del(ctx, key).Result() | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
|  | |||||||
| @ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer { | |||||||
| 	globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) | 	globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) | ||||||
| 	globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) | 	globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) | ||||||
| 	rootMux.PathPrefix("/").HandlerFunc(s.Handle) | 	rootMux.PathPrefix("/").HandlerFunc(s.Handle) | ||||||
|  | 	ac.AddWSHandler(s.handleWSMessage) | ||||||
| 	return s | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										49
									
								
								internal/outpost/proxyv2/ws.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								internal/outpost/proxyv2/ws.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | |||||||
|  | package proxyv2 | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  |  | ||||||
|  | 	"github.com/mitchellh/mapstructure" | ||||||
|  | 	"goauthentik.io/internal/outpost/proxyv2/application" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type WSProviderSubType string | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	WSProviderSubTypeLogout WSProviderSubType = "logout" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type WSProviderMsg struct { | ||||||
|  | 	SubType   WSProviderSubType `mapstructure:"sub_type"` | ||||||
|  | 	SessionID string            `mapstructure:"session_id"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { | ||||||
|  | 	msg := &WSProviderMsg{} | ||||||
|  | 	err := mapstructure.Decode(args, &msg) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return msg, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { | ||||||
|  | 	msg, err := ParseWSProvider(args) | ||||||
|  | 	if err != nil { | ||||||
|  | 		ps.log.WithError(err).Warning("invalid provider-specific ws message") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	switch msg.SubType { | ||||||
|  | 	case WSProviderSubTypeLogout: | ||||||
|  | 		for _, p := range ps.apps { | ||||||
|  | 			err := p.Logout(ctx, func(c application.Claims) bool { | ||||||
|  | 				return c.Sid == msg.SessionID | ||||||
|  | 			}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L