providers/proxy: improve SLO by backchannel logging out sessions (#7099)
* outposts: add support for provider-specific websocket messages Signed-off-by: Jens Langhammer <jens@goauthentik.io> * providers/proxy: add custom signal on logout to logout in provider Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -27,6 +27,9 @@ class WebsocketMessageInstruction(IntEnum): | ||||
|     # Message sent by us to trigger an Update | ||||
|     TRIGGER_UPDATE = 2 | ||||
|  | ||||
|     # Provider specific message | ||||
|     PROVIDER_SPECIFIC = 3 | ||||
|  | ||||
|  | ||||
| @dataclass(slots=True) | ||||
| class WebsocketMessage: | ||||
| @ -131,3 +134,14 @@ class OutpostConsumer(AuthJsonConsumer): | ||||
|         self.send_json( | ||||
|             asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) | ||||
|         ) | ||||
|  | ||||
|     def event_provider_specific(self, event): | ||||
|         """Event handler which can be called by provider-specific | ||||
|         implementations to send specific messages to the outpost""" | ||||
|         self.send_json( | ||||
|             asdict( | ||||
|                 WebsocketMessage( | ||||
|                     instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event | ||||
|                 ) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
| @ -5,7 +5,6 @@ from socket import gethostname | ||||
| from typing import Any, Optional | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| import yaml | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.core.cache import cache | ||||
| @ -16,6 +15,7 @@ from docker.constants import DEFAULT_UNIX_SOCKET | ||||
| from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | ||||
| from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | ||||
| from structlog.stdlib import get_logger | ||||
| from yaml import safe_load | ||||
|  | ||||
| from authentik.events.monitored_tasks import ( | ||||
|     MonitoredTask, | ||||
| @ -279,7 +279,7 @@ def outpost_connection_discovery(self: MonitoredTask): | ||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||
|                 KubernetesServiceConnection.objects.create( | ||||
|                     name=kubeconfig_local_name, | ||||
|                     kubeconfig=yaml.safe_load(_kubeconfig), | ||||
|                     kubeconfig=safe_load(_kubeconfig), | ||||
|                 ) | ||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||
|     socket = Path(unix_socket_path) | ||||
|  | ||||
| @ -9,3 +9,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | ||||
|     label = "authentik_providers_proxy" | ||||
|     verbose_name = "authentik Providers.Proxy" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_load_providers_proxy_signals(self): | ||||
|         """Load proxy signals""" | ||||
|         self.import_module("authentik.providers.proxy.signals") | ||||
|  | ||||
							
								
								
									
										20
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | ||||
| """Proxy provider signals""" | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models.signals import pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.proxy.tasks import proxy_on_logout | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||
|     """Catch logout by direct logout and forward to proxy providers""" | ||||
|     proxy_on_logout.delay(request.session.session_key) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|     proxy_on_logout.delay(instance.session_key) | ||||
| @ -1,6 +1,9 @@ | ||||
| """proxy provider tasks""" | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
|  | ||||
| from authentik.outposts.models import Outpost, OutpostState, OutpostType | ||||
| from authentik.providers.proxy.models import ProxyProvider | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| @ -13,3 +16,20 @@ def proxy_set_defaults(): | ||||
|     for provider in ProxyProvider.objects.all(): | ||||
|         provider.set_oauth_defaults() | ||||
|         provider.save() | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def proxy_on_logout(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||
|         for state in OutpostState.for_outpost(outpost): | ||||
|             for channel in state.channel_ids: | ||||
|                 async_to_sync(layer.send)( | ||||
|                     channel, | ||||
|                     { | ||||
|                         "type": "event.provider.specific", | ||||
|                         "sub_type": "logout", | ||||
|                         "session_id": session_id, | ||||
|                     }, | ||||
|                 ) | ||||
|  | ||||
| @ -15,6 +15,7 @@ entries: | ||||
|         # This mapping is used by the authentik proxy. It passes extra user attributes, | ||||
|         # which are used for example for the HTTP-Basic Authentication mapping. | ||||
|         return { | ||||
|             "sid": request.http_request.session.session_key, | ||||
|             "ak_proxy": { | ||||
|                 "user_attributes": request.user.group_attributes(request), | ||||
|                 "is_superuser": request.user.is_superuser, | ||||
|  | ||||
| @ -22,6 +22,8 @@ import ( | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| ) | ||||
|  | ||||
| type WSHandler func(ctx context.Context, args map[string]interface{}) | ||||
|  | ||||
| const ConfigLogLevel = "log_level" | ||||
|  | ||||
| // APIController main controller which connects to the authentik api via http and ws | ||||
| @ -42,6 +44,7 @@ type APIController struct { | ||||
| 	lastWsReconnect     time.Time | ||||
| 	wsIsReconnecting    bool | ||||
| 	wsBackoffMultiplier int | ||||
| 	wsHandlers          []WSHandler | ||||
| 	refreshHandlers     []func() | ||||
|  | ||||
| 	instanceUUID uuid.UUID | ||||
| @ -106,6 +109,7 @@ func NewAPIController(akURL url.URL, token string) *APIController { | ||||
| 		reloadOffset:        time.Duration(rand.Intn(10)) * time.Second, | ||||
| 		instanceUUID:        uuid.New(), | ||||
| 		Outpost:             outpost, | ||||
| 		wsHandlers:          []WSHandler{}, | ||||
| 		wsBackoffMultiplier: 1, | ||||
| 		refreshHandlers:     make([]func(), 0), | ||||
| 	} | ||||
| @ -156,6 +160,10 @@ func (a *APIController) AddRefreshHandler(handler func()) { | ||||
| 	a.refreshHandlers = append(a.refreshHandlers, handler) | ||||
| } | ||||
|  | ||||
| func (a *APIController) AddWSHandler(handler WSHandler) { | ||||
| 	a.wsHandlers = append(a.wsHandlers, handler) | ||||
| } | ||||
|  | ||||
| func (a *APIController) OnRefresh() error { | ||||
| 	// Because we don't know the outpost UUID, we simply do a list and pick the first | ||||
| 	// The service account this token belongs to should only have access to a single outpost | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| package ak | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| @ -145,6 +146,10 @@ func (ac *APIController) startWSHandler() { | ||||
| 					"build":        constants.BUILD("tagged"), | ||||
| 				}).SetToCurrentTime() | ||||
| 			} | ||||
| 		} else if wsMsg.Instruction == WebsocketInstructionProviderSpecific { | ||||
| 			for _, h := range ac.wsHandlers { | ||||
| 				h(context.Background(), wsMsg.Args) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -9,6 +9,8 @@ const ( | ||||
| 	WebsocketInstructionHello websocketInstruction = 1 | ||||
| 	// WebsocketInstructionTriggerUpdate Code received to trigger a config update | ||||
| 	WebsocketInstructionTriggerUpdate websocketInstruction = 2 | ||||
| 	// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function | ||||
| 	WebsocketInstructionProviderSpecific websocketInstruction = 3 | ||||
| ) | ||||
|  | ||||
| type websocketMessage struct { | ||||
|  | ||||
| @ -280,7 +280,9 @@ func (a *Application) handleSignOut(rw http.ResponseWriter, r *http.Request) { | ||||
| 		"id_token_hint": []string{cc.RawToken}, | ||||
| 	} | ||||
| 	redirect += "?" + uv.Encode() | ||||
| 	err = a.Logout(r.Context(), cc.Sub) | ||||
| 	err = a.Logout(r.Context(), func(c Claims) bool { | ||||
| 		return c.Sub == cc.Sub | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		a.log.WithError(err).Warning("failed to logout of other sessions") | ||||
| 	} | ||||
|  | ||||
| @ -11,10 +11,11 @@ type Claims struct { | ||||
| 	Exp               int          `json:"exp"` | ||||
| 	Email             string       `json:"email"` | ||||
| 	Verified          bool         `json:"email_verified"` | ||||
| 	Proxy             *ProxyClaims `json:"ak_proxy"` | ||||
| 	Name              string       `json:"name"` | ||||
| 	PreferredUsername string       `json:"preferred_username"` | ||||
| 	Groups            []string     `json:"groups"` | ||||
| 	Sid               string       `json:"sid"` | ||||
| 	Proxy             *ProxyClaims `json:"ak_proxy"` | ||||
|  | ||||
| 	RawToken string | ||||
| } | ||||
|  | ||||
| @ -88,7 +88,7 @@ func (a *Application) getAllCodecs() []securecookie.Codec { | ||||
| 	return cs | ||||
| } | ||||
|  | ||||
| func (a *Application) Logout(ctx context.Context, sub string) error { | ||||
| func (a *Application) Logout(ctx context.Context, filter func(c Claims) bool) error { | ||||
| 	if _, ok := a.sessions.(*sessions.FilesystemStore); ok { | ||||
| 		files, err := os.ReadDir(os.TempDir()) | ||||
| 		if err != nil { | ||||
| @ -118,7 +118,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { | ||||
| 				continue | ||||
| 			} | ||||
| 			claims := s.Values[constants.SessionClaims].(Claims) | ||||
| 			if claims.Sub == sub { | ||||
| 			if filter(claims) { | ||||
| 				a.log.WithField("path", fullPath).Trace("deleting session") | ||||
| 				err := os.Remove(fullPath) | ||||
| 				if err != nil { | ||||
| @ -153,7 +153,7 @@ func (a *Application) Logout(ctx context.Context, sub string) error { | ||||
| 				continue | ||||
| 			} | ||||
| 			claims := c.(Claims) | ||||
| 			if claims.Sub == sub { | ||||
| 			if filter(claims) { | ||||
| 				a.log.WithField("key", key).Trace("deleting session") | ||||
| 				_, err := client.Del(ctx, key).Result() | ||||
| 				if err != nil { | ||||
|  | ||||
| @ -65,6 +65,7 @@ func NewProxyServer(ac *ak.APIController) *ProxyServer { | ||||
| 	globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) | ||||
| 	globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) | ||||
| 	rootMux.PathPrefix("/").HandlerFunc(s.Handle) | ||||
| 	ac.AddWSHandler(s.handleWSMessage) | ||||
| 	return s | ||||
| } | ||||
|  | ||||
|  | ||||
							
								
								
									
										49
									
								
								internal/outpost/proxyv2/ws.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								internal/outpost/proxyv2/ws.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | ||||
| package proxyv2 | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	"goauthentik.io/internal/outpost/proxyv2/application" | ||||
| ) | ||||
|  | ||||
| type WSProviderSubType string | ||||
|  | ||||
| const ( | ||||
| 	WSProviderSubTypeLogout WSProviderSubType = "logout" | ||||
| ) | ||||
|  | ||||
| type WSProviderMsg struct { | ||||
| 	SubType   WSProviderSubType `mapstructure:"sub_type"` | ||||
| 	SessionID string            `mapstructure:"session_id"` | ||||
| } | ||||
|  | ||||
| func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { | ||||
| 	msg := &WSProviderMsg{} | ||||
| 	err := mapstructure.Decode(args, &msg) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return msg, nil | ||||
| } | ||||
|  | ||||
| func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { | ||||
| 	msg, err := ParseWSProvider(args) | ||||
| 	if err != nil { | ||||
| 		ps.log.WithError(err).Warning("invalid provider-specific ws message") | ||||
| 		return | ||||
| 	} | ||||
| 	switch msg.SubType { | ||||
| 	case WSProviderSubTypeLogout: | ||||
| 		for _, p := range ps.apps { | ||||
| 			err := p.Logout(ctx, func(c application.Claims) bool { | ||||
| 				return c.Sid == msg.SessionID | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") | ||||
| 			} | ||||
| 		} | ||||
| 	default: | ||||
| 		ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Jens L
					Jens L