outposts: Refactor session end signal and add LDAP support (#14539)
* outpost: promote session end signal to non-provider specific Signed-off-by: Jens Langhammer <jens@goauthentik.io> * implement server-side logout in ldap Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix previous import Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use better retry logic Signed-off-by: Jens Langhammer <jens@goauthentik.io> * log Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make more generic if we switch from ws to something else Signed-off-by: Jens Langhammer <jens@goauthentik.io> * make it possible to e2e test WS Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix ldap session id Signed-off-by: Jens Langhammer <jens@goauthentik.io> * ok I actually need to go to bed this took me an hour to fix Signed-off-by: Jens Langhammer <jens@goauthentik.io> * format; add ldap test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix leftover state Signed-off-by: Jens Langhammer <jens@goauthentik.io> * remove thread Signed-off-by: Jens Langhammer <jens@goauthentik.io> * use ws base for radius Signed-off-by: Jens Langhammer <jens@goauthentik.io> * separate test utils Signed-off-by: Jens Langhammer <jens@goauthentik.io> * rename Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix missing super calls Signed-off-by: Jens Langhammer <jens@goauthentik.io> * websocket tests with browser 🎉 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * add proxy test for sign out Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix install_id issue with channels tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix proxy basic auth test Signed-off-by: Jens Langhammer <jens@goauthentik.io> * big code dedupe Signed-off-by: Jens Langhammer <jens@goauthentik.io> * allow passing go build args Signed-off-by: Jens Langhammer <jens@goauthentik.io> * improve waiting for outpost Signed-off-by: Jens Langhammer <jens@goauthentik.io> * rewrite ldap tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * ok actually fix the tests Signed-off-by: Jens Langhammer <jens@goauthentik.io> * undo a couple things that need more time to cook Signed-off-by: Jens Langhammer <jens@goauthentik.io> * remove unused lockfile-lint dependency since we use a shell script and SFE does not have a lockfile Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix session id for ldap Signed-off-by: Jens Langhammer <jens@goauthentik.io> * fix missing createTimestamp and modifyTimestamp ldap attributes closes #10474 Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
		| @ -11,3 +11,4 @@ blueprints/local | ||||
| !gen-ts-api/node_modules | ||||
| !gen-ts-api/dist/** | ||||
| !gen-go-api/ | ||||
| .venv | ||||
|  | ||||
| @ -37,6 +37,9 @@ class WebsocketMessageInstruction(IntEnum): | ||||
|     # Provider specific message | ||||
|     PROVIDER_SPECIFIC = 3 | ||||
|  | ||||
|     # Session ended | ||||
|     SESSION_END = 4 | ||||
|  | ||||
|  | ||||
| @dataclass(slots=True) | ||||
| class WebsocketMessage: | ||||
| @ -145,6 +148,14 @@ class OutpostConsumer(JsonWebsocketConsumer): | ||||
|             asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) | ||||
|         ) | ||||
|  | ||||
|     def event_session_end(self, event): | ||||
|         """Event handler which is called when a session is ended""" | ||||
|         self.send_json( | ||||
|             asdict( | ||||
|                 WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event) | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def event_provider_specific(self, event): | ||||
|         """Event handler which can be called by provider-specific | ||||
|         implementations to send specific messages to the outpost""" | ||||
|  | ||||
| @ -1,17 +1,24 @@ | ||||
| """authentik outpost signals""" | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import Provider | ||||
| from authentik.core.models import AuthenticatedSession, Provider, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||
| from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save | ||||
| from authentik.outposts.tasks import ( | ||||
|     CACHE_KEY_OUTPOST_DOWN, | ||||
|     outpost_controller, | ||||
|     outpost_post_save, | ||||
|     outpost_session_end, | ||||
| ) | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| UPDATE_TRIGGERING_MODELS = ( | ||||
| @ -73,3 +80,17 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): | ||||
|     instance.user.delete() | ||||
|     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) | ||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||
|     """Catch logout by direct logout and forward to providers""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     outpost_session_end.delay(request.session.session_key) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|     outpost_session_end.delay(instance.session.session_key) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| """outpost tasks""" | ||||
|  | ||||
| from hashlib import sha256 | ||||
| from os import R_OK, access | ||||
| from pathlib import Path | ||||
| from socket import gethostname | ||||
| @ -49,6 +50,11 @@ LOGGER = get_logger() | ||||
| CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | ||||
|  | ||||
|  | ||||
| def hash_session_key(session_key: str) -> str: | ||||
|     """Hash the session key for sending session end signals""" | ||||
|     return sha256(session_key.encode("ascii")).hexdigest() | ||||
|  | ||||
|  | ||||
| def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: | ||||
|     """Get a controller for the outpost, when a service connection is defined""" | ||||
|     if not outpost.service_connection: | ||||
| @ -289,3 +295,20 @@ def outpost_connection_discovery(self: SystemTask): | ||||
|                 url=unix_socket_path, | ||||
|             ) | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def outpost_session_end(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     hashed_session_id = hash_session_key(session_id) | ||||
|     for outpost in Outpost.objects.all(): | ||||
|         LOGGER.info("Sending session end signal to outpost", outpost=outpost) | ||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||
|         async_to_sync(layer.group_send)( | ||||
|             group, | ||||
|             { | ||||
|                 "type": "event.session.end", | ||||
|                 "session_id": hashed_session_id, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -1,23 +0,0 @@ | ||||
| """Proxy provider signals""" | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models.signals import pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.proxy.tasks import proxy_on_logout | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||
|     """Catch logout by direct logout and forward to proxy providers""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     proxy_on_logout.delay(request.session.session_key) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|     proxy_on_logout.delay(instance.session.session_key) | ||||
| @ -1,26 +0,0 @@ | ||||
| """proxy provider tasks""" | ||||
|  | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
|  | ||||
| from authentik.outposts.consumer import OUTPOST_GROUP | ||||
| from authentik.outposts.models import Outpost, OutpostType | ||||
| from authentik.providers.oauth2.id_token import hash_session_key | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def proxy_on_logout(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     hashed_session_id = hash_session_key(session_id) | ||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||
|         async_to_sync(layer.group_send)( | ||||
|             group, | ||||
|             { | ||||
|                 "type": "event.provider.specific", | ||||
|                 "sub_type": "logout", | ||||
|                 "session_id": hashed_session_id, | ||||
|             }, | ||||
|         ) | ||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @ -4,6 +4,7 @@ go 1.24.0 | ||||
|  | ||||
| require ( | ||||
| 	beryju.io/ldap v0.1.0 | ||||
| 	github.com/avast/retry-go/v4 v4.6.1 | ||||
| 	github.com/coreos/go-oidc/v3 v3.14.1 | ||||
| 	github.com/getsentry/sentry-go v0.33.0 | ||||
| 	github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 | ||||
|  | ||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @ -41,6 +41,8 @@ github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7V | ||||
| github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= | ||||
| github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= | ||||
| github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= | ||||
| github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk= | ||||
| github.com/avast/retry-go/v4 v4.6.1/go.mod h1:V6oF8njAwxJ5gRo1Q7Cxab24xs5NCWZBeaHHBklR8mA= | ||||
| github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= | ||||
| github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= | ||||
| github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= | ||||
|  | ||||
| @ -13,6 +13,7 @@ import ( | ||||
| 	"syscall" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/avast/retry-go/v4" | ||||
| 	"github.com/getsentry/sentry-go" | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/gorilla/websocket" | ||||
| @ -25,8 +26,6 @@ import ( | ||||
| 	"goauthentik.io/internal/utils/web" | ||||
| ) | ||||
|  | ||||
| type WSHandler func(ctx context.Context, args map[string]interface{}) | ||||
|  | ||||
| const ConfigLogLevel = "log_level" | ||||
|  | ||||
| // APIController main controller which connects to the authentik api via http and ws | ||||
| @ -43,12 +42,11 @@ type APIController struct { | ||||
|  | ||||
| 	reloadOffset time.Duration | ||||
|  | ||||
| 	wsConn              *websocket.Conn | ||||
| 	lastWsReconnect     time.Time | ||||
| 	wsIsReconnecting    bool | ||||
| 	wsBackoffMultiplier int | ||||
| 	wsHandlers          []WSHandler | ||||
| 	refreshHandlers     []func() | ||||
| 	eventConn        *websocket.Conn | ||||
| 	lastWsReconnect  time.Time | ||||
| 	wsIsReconnecting bool | ||||
| 	eventHandlers    []EventHandler | ||||
| 	refreshHandlers  []func() | ||||
|  | ||||
| 	instanceUUID uuid.UUID | ||||
| } | ||||
| @ -83,20 +81,19 @@ func NewAPIController(akURL url.URL, token string) *APIController { | ||||
|  | ||||
| 	// Because we don't know the outpost UUID, we simply do a list and pick the first | ||||
| 	// The service account this token belongs to should only have access to a single outpost | ||||
| 	var outposts *api.PaginatedOutpostList | ||||
| 	var err error | ||||
| 	for { | ||||
| 		outposts, _, err = apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute() | ||||
|  | ||||
| 		if err == nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds") | ||||
| 		time.Sleep(time.Second * 3) | ||||
| 	} | ||||
| 	outposts, _ := retry.DoWithData[*api.PaginatedOutpostList]( | ||||
| 		func() (*api.PaginatedOutpostList, error) { | ||||
| 			outposts, _, err := apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute() | ||||
| 			return outposts, err | ||||
| 		}, | ||||
| 		retry.Attempts(0), | ||||
| 		retry.Delay(time.Second*3), | ||||
| 		retry.OnRetry(func(attempt uint, err error) { | ||||
| 			log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds") | ||||
| 		}), | ||||
| 	) | ||||
| 	if len(outposts.Results) < 1 { | ||||
| 		panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") | ||||
| 		log.Panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") | ||||
| 	} | ||||
| 	outpost := outposts.Results[0] | ||||
|  | ||||
| @ -119,17 +116,16 @@ func NewAPIController(akURL url.URL, token string) *APIController { | ||||
| 		token:  token, | ||||
| 		logger: log, | ||||
|  | ||||
| 		reloadOffset:        time.Duration(rand.Intn(10)) * time.Second, | ||||
| 		instanceUUID:        uuid.New(), | ||||
| 		Outpost:             outpost, | ||||
| 		wsHandlers:          []WSHandler{}, | ||||
| 		wsBackoffMultiplier: 1, | ||||
| 		refreshHandlers:     make([]func(), 0), | ||||
| 		reloadOffset:    time.Duration(rand.Intn(10)) * time.Second, | ||||
| 		instanceUUID:    uuid.New(), | ||||
| 		Outpost:         outpost, | ||||
| 		eventHandlers:   []EventHandler{}, | ||||
| 		refreshHandlers: make([]func(), 0), | ||||
| 	} | ||||
| 	ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") | ||||
| 	err = ac.initWS(akURL, outpost.Pk) | ||||
| 	err = ac.initEvent(akURL, outpost.Pk) | ||||
| 	if err != nil { | ||||
| 		go ac.reconnectWS() | ||||
| 		go ac.recentEvents() | ||||
| 	} | ||||
| 	ac.configureRefreshSignal() | ||||
| 	return ac | ||||
| @ -200,7 +196,7 @@ func (a *APIController) OnRefresh() error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (a *APIController) getWebsocketPingArgs() map[string]interface{} { | ||||
| func (a *APIController) getEventPingArgs() map[string]interface{} { | ||||
| 	args := map[string]interface{}{ | ||||
| 		"version":        constants.VERSION, | ||||
| 		"buildHash":      constants.BUILD(""), | ||||
| @ -226,12 +222,12 @@ func (a *APIController) StartBackgroundTasks() error { | ||||
| 		"build":        constants.BUILD(""), | ||||
| 	}).Set(1) | ||||
| 	go func() { | ||||
| 		a.logger.Debug("Starting WS Handler...") | ||||
| 		a.startWSHandler() | ||||
| 		a.logger.Debug("Starting Event Handler...") | ||||
| 		a.startEventHandler() | ||||
| 	}() | ||||
| 	go func() { | ||||
| 		a.logger.Debug("Starting WS Health notifier...") | ||||
| 		a.startWSHealth() | ||||
| 		a.logger.Debug("Starting Event health notifier...") | ||||
| 		a.startEventHealth() | ||||
| 	}() | ||||
| 	go func() { | ||||
| 		a.logger.Debug("Starting Interval updater...") | ||||
|  | ||||
| @ -11,6 +11,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/avast/retry-go/v4" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/prometheus/client_golang/prometheus" | ||||
| 	"goauthentik.io/internal/config" | ||||
| @ -30,7 +31,7 @@ func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, quer | ||||
| 	return wsUrl | ||||
| } | ||||
| 
 | ||||
| func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { | ||||
| func (ac *APIController) initEvent(akURL url.URL, outpostUUID string) error { | ||||
| 	query := akURL.Query() | ||||
| 	query.Set("instance_uuid", ac.instanceUUID.String()) | ||||
| 
 | ||||
| @ -57,19 +58,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ac.wsConn = ws | ||||
| 	ac.eventConn = ws | ||||
| 	// Send hello message with our version | ||||
| 	msg := websocketMessage{ | ||||
| 		Instruction: WebsocketInstructionHello, | ||||
| 		Args:        ac.getWebsocketPingArgs(), | ||||
| 	msg := Event{ | ||||
| 		Instruction: EventKindHello, | ||||
| 		Args:        ac.getEventPingArgs(), | ||||
| 	} | ||||
| 	err = ws.WriteJSON(msg) | ||||
| 	if err != nil { | ||||
| 		ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik") | ||||
| 		ac.logger.WithField("logger", "authentik.outpost.events").WithError(err).Warning("Failed to hello to authentik") | ||||
| 		return err | ||||
| 	} | ||||
| 	ac.lastWsReconnect = time.Now() | ||||
| 	ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket") | ||||
| 	ac.logger.WithField("logger", "authentik.outpost.events").WithField("outpost", outpostUUID).Info("Successfully connected websocket") | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| @ -77,19 +78,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { | ||||
| func (ac *APIController) Shutdown() { | ||||
| 	// Cleanly close the connection by sending a close message and then | ||||
| 	// waiting (with timeout) for the server to close the connection. | ||||
| 	err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) | ||||
| 	err := ac.eventConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) | ||||
| 	if err != nil { | ||||
| 		ac.logger.WithError(err).Warning("failed to write close message") | ||||
| 		return | ||||
| 	} | ||||
| 	err = ac.wsConn.Close() | ||||
| 	err = ac.eventConn.Close() | ||||
| 	if err != nil { | ||||
| 		ac.logger.WithError(err).Warning("failed to close websocket") | ||||
| 	} | ||||
| 	ac.logger.Info("finished shutdown") | ||||
| } | ||||
| 
 | ||||
| func (ac *APIController) reconnectWS() { | ||||
| func (ac *APIController) recentEvents() { | ||||
| 	if ac.wsIsReconnecting { | ||||
| 		return | ||||
| 	} | ||||
| @ -100,46 +101,47 @@ func (ac *APIController) reconnectWS() { | ||||
| 		Path:   strings.ReplaceAll(ac.Client.GetConfig().Servers[0].URL, "api/v3", ""), | ||||
| 	} | ||||
| 	attempt := 1 | ||||
| 	for { | ||||
| 		q := u.Query() | ||||
| 		q.Set("attempt", strconv.Itoa(attempt)) | ||||
| 		u.RawQuery = q.Encode() | ||||
| 		err := ac.initWS(u, ac.Outpost.Pk) | ||||
| 		attempt += 1 | ||||
| 		if err != nil { | ||||
| 			ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier) | ||||
| 			time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second) | ||||
| 			ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2 | ||||
| 			// Limit to 300 seconds (5m) | ||||
| 			if ac.wsBackoffMultiplier >= 300 { | ||||
| 				ac.wsBackoffMultiplier = 300 | ||||
| 	_ = retry.Do( | ||||
| 		func() error { | ||||
| 			q := u.Query() | ||||
| 			q.Set("attempt", strconv.Itoa(attempt)) | ||||
| 			u.RawQuery = q.Encode() | ||||
| 			err := ac.initEvent(u, ac.Outpost.Pk) | ||||
| 			attempt += 1 | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} else { | ||||
| 			ac.wsIsReconnecting = false | ||||
| 			ac.wsBackoffMultiplier = 1 | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 			return nil | ||||
| 		}, | ||||
| 		retry.Delay(1*time.Second), | ||||
| 		retry.MaxDelay(5*time.Minute), | ||||
| 		retry.DelayType(retry.BackOffDelay), | ||||
| 		retry.Attempts(0), | ||||
| 		retry.OnRetry(func(attempt uint, err error) { | ||||
| 			ac.logger.Infof("waiting %d seconds to reconnect", attempt) | ||||
| 		}), | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| func (ac *APIController) startWSHandler() { | ||||
| 	logger := ac.logger.WithField("loop", "ws-handler") | ||||
| func (ac *APIController) startEventHandler() { | ||||
| 	logger := ac.logger.WithField("loop", "event-handler") | ||||
| 	for { | ||||
| 		var wsMsg websocketMessage | ||||
| 		if ac.wsConn == nil { | ||||
| 			go ac.reconnectWS() | ||||
| 		var wsMsg Event | ||||
| 		if ac.eventConn == nil { | ||||
| 			go ac.recentEvents() | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 			continue | ||||
| 		} | ||||
| 		err := ac.wsConn.ReadJSON(&wsMsg) | ||||
| 		err := ac.eventConn.ReadJSON(&wsMsg) | ||||
| 		if err != nil { | ||||
| 			ConnectionStatus.With(prometheus.Labels{ | ||||
| 				"outpost_name": ac.Outpost.Name, | ||||
| 				"outpost_type": ac.Server.Type(), | ||||
| 				"uuid":         ac.instanceUUID.String(), | ||||
| 			}).Set(0) | ||||
| 			logger.WithError(err).Warning("ws read error") | ||||
| 			go ac.reconnectWS() | ||||
| 			logger.WithError(err).Warning("event read error") | ||||
| 			go ac.recentEvents() | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 			continue | ||||
| 		} | ||||
| @ -149,7 +151,8 @@ func (ac *APIController) startWSHandler() { | ||||
| 			"uuid":         ac.instanceUUID.String(), | ||||
| 		}).Set(1) | ||||
| 		switch wsMsg.Instruction { | ||||
| 		case WebsocketInstructionTriggerUpdate: | ||||
| 		case EventKindAck: | ||||
| 		case EventKindTriggerUpdate: | ||||
| 			time.Sleep(ac.reloadOffset) | ||||
| 			logger.Debug("Got update trigger...") | ||||
| 			err := ac.OnRefresh() | ||||
| @ -164,30 +167,33 @@ func (ac *APIController) startWSHandler() { | ||||
| 					"build":        constants.BUILD(""), | ||||
| 				}).SetToCurrentTime() | ||||
| 			} | ||||
| 		case WebsocketInstructionProviderSpecific: | ||||
| 			for _, h := range ac.wsHandlers { | ||||
| 				h(context.Background(), wsMsg.Args) | ||||
| 		default: | ||||
| 			for _, h := range ac.eventHandlers { | ||||
| 				err := h(context.Background(), wsMsg) | ||||
| 				if err != nil { | ||||
| 					ac.logger.WithError(err).Warning("failed to run event handler") | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (ac *APIController) startWSHealth() { | ||||
| func (ac *APIController) startEventHealth() { | ||||
| 	ticker := time.NewTicker(time.Second * 10) | ||||
| 	for ; true; <-ticker.C { | ||||
| 		if ac.wsConn == nil { | ||||
| 			go ac.reconnectWS() | ||||
| 		if ac.eventConn == nil { | ||||
| 			go ac.recentEvents() | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 			continue | ||||
| 		} | ||||
| 		err := ac.SendWSHello(map[string]interface{}{}) | ||||
| 		err := ac.SendEventHello(map[string]interface{}{}) | ||||
| 		if err != nil { | ||||
| 			ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") | ||||
| 			go ac.reconnectWS() | ||||
| 			ac.logger.WithField("loop", "event-health").WithError(err).Warning("event write error") | ||||
| 			go ac.recentEvents() | ||||
| 			time.Sleep(time.Second * 5) | ||||
| 			continue | ||||
| 		} else { | ||||
| 			ac.logger.WithField("loop", "ws-health").Trace("hello'd") | ||||
| 			ac.logger.WithField("loop", "event-health").Trace("hello'd") | ||||
| 			ConnectionStatus.With(prometheus.Labels{ | ||||
| 				"outpost_name": ac.Outpost.Name, | ||||
| 				"outpost_type": ac.Server.Type(), | ||||
| @ -230,19 +236,19 @@ func (ac *APIController) startIntervalUpdater() { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (a *APIController) AddWSHandler(handler WSHandler) { | ||||
| 	a.wsHandlers = append(a.wsHandlers, handler) | ||||
| func (a *APIController) AddEventHandler(handler EventHandler) { | ||||
| 	a.eventHandlers = append(a.eventHandlers, handler) | ||||
| } | ||||
| 
 | ||||
| func (a *APIController) SendWSHello(args map[string]interface{}) error { | ||||
| 	allArgs := a.getWebsocketPingArgs() | ||||
| func (a *APIController) SendEventHello(args map[string]interface{}) error { | ||||
| 	allArgs := a.getEventPingArgs() | ||||
| 	for key, value := range args { | ||||
| 		allArgs[key] = value | ||||
| 	} | ||||
| 	aliveMsg := websocketMessage{ | ||||
| 		Instruction: WebsocketInstructionHello, | ||||
| 	aliveMsg := Event{ | ||||
| 		Instruction: EventKindHello, | ||||
| 		Args:        allArgs, | ||||
| 	} | ||||
| 	err := a.wsConn.WriteJSON(aliveMsg) | ||||
| 	err := a.eventConn.WriteJSON(aliveMsg) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										37
									
								
								internal/outpost/ak/api_event_msg.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								internal/outpost/ak/api_event_msg.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | ||||
| package ak | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| ) | ||||
|  | ||||
| type EventKind int | ||||
|  | ||||
| const ( | ||||
| 	// Code used to acknowledge a previous message | ||||
| 	EventKindAck EventKind = 0 | ||||
| 	// Code used to send a healthcheck keepalive | ||||
| 	EventKindHello EventKind = 1 | ||||
| 	// Code received to trigger a config update | ||||
| 	EventKindTriggerUpdate EventKind = 2 | ||||
| 	// Code received to trigger some provider specific function | ||||
| 	EventKindProviderSpecific EventKind = 3 | ||||
| 	// Code received to identify the end of a session | ||||
| 	EventKindSessionEnd EventKind = 4 | ||||
| ) | ||||
|  | ||||
| type EventHandler func(ctx context.Context, msg Event) error | ||||
|  | ||||
| type Event struct { | ||||
| 	Instruction EventKind   `json:"instruction"` | ||||
| 	Args        interface{} `json:"args"` | ||||
| } | ||||
|  | ||||
| func (wm Event) ArgsAs(out interface{}) error { | ||||
| 	return mapstructure.Decode(wm.Args, out) | ||||
| } | ||||
|  | ||||
| type EventArgsSessionEnd struct { | ||||
| 	SessionID string `mapstructure:"session_id"` | ||||
| } | ||||
| @ -15,7 +15,7 @@ func URLMustParse(u string) *url.URL { | ||||
| 	return ur | ||||
| } | ||||
| 
 | ||||
| func TestWebsocketURL(t *testing.T) { | ||||
| func TestEventWebsocketURL(t *testing.T) { | ||||
| 	u := URLMustParse("http://localhost:9000?foo=bar") | ||||
| 	uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" | ||||
| 	ac := &APIController{} | ||||
| @ -23,7 +23,7 @@ func TestWebsocketURL(t *testing.T) { | ||||
| 	assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String()) | ||||
| } | ||||
| 
 | ||||
| func TestWebsocketURL_Query(t *testing.T) { | ||||
| func TestEventWebsocketURL_Query(t *testing.T) { | ||||
| 	u := URLMustParse("http://localhost:9000?foo=bar") | ||||
| 	uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" | ||||
| 	ac := &APIController{} | ||||
| @ -33,7 +33,7 @@ func TestWebsocketURL_Query(t *testing.T) { | ||||
| 	assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String()) | ||||
| } | ||||
| 
 | ||||
| func TestWebsocketURL_Subpath(t *testing.T) { | ||||
| func TestEventWebsocketURL_Subpath(t *testing.T) { | ||||
| 	u := URLMustParse("http://localhost:9000/foo/bar/") | ||||
| 	uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" | ||||
| 	ac := &APIController{} | ||||
| @ -1,19 +0,0 @@ | ||||
| package ak | ||||
|  | ||||
| type websocketInstruction int | ||||
|  | ||||
| const ( | ||||
| 	// WebsocketInstructionAck Code used to acknowledge a previous message | ||||
| 	WebsocketInstructionAck websocketInstruction = 0 | ||||
| 	// WebsocketInstructionHello Code used to send a healthcheck keepalive | ||||
| 	WebsocketInstructionHello websocketInstruction = 1 | ||||
| 	// WebsocketInstructionTriggerUpdate Code received to trigger a config update | ||||
| 	WebsocketInstructionTriggerUpdate websocketInstruction = 2 | ||||
| 	// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function | ||||
| 	WebsocketInstructionProviderSpecific websocketInstruction = 3 | ||||
| ) | ||||
|  | ||||
| type websocketMessage struct { | ||||
| 	Instruction websocketInstruction   `json:"instruction"` | ||||
| 	Args        map[string]interface{} `json:"args"` | ||||
| } | ||||
| @ -55,11 +55,10 @@ func MockAK(outpost api.Outpost, globalConfig api.Config) *APIController { | ||||
| 		token:  token, | ||||
| 		logger: log, | ||||
|  | ||||
| 		reloadOffset:        time.Duration(rand.Intn(10)) * time.Second, | ||||
| 		instanceUUID:        uuid.New(), | ||||
| 		Outpost:             outpost, | ||||
| 		wsBackoffMultiplier: 1, | ||||
| 		refreshHandlers:     make([]func(), 0), | ||||
| 		reloadOffset:    time.Duration(rand.Intn(10)) * time.Second, | ||||
| 		instanceUUID:    uuid.New(), | ||||
| 		Outpost:         outpost, | ||||
| 		refreshHandlers: make([]func(), 0), | ||||
| 	} | ||||
| 	ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") | ||||
| 	return ac | ||||
|  | ||||
| @ -127,7 +127,7 @@ func (fe *FlowExecutor) getAnswer(stage StageComponent) string { | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (fe *FlowExecutor) GetSession() *http.Cookie { | ||||
| func (fe *FlowExecutor) SessionCookie() *http.Cookie { | ||||
| 	return fe.session | ||||
| } | ||||
|  | ||||
|  | ||||
							
								
								
									
										19
									
								
								internal/outpost/flow/session.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								internal/outpost/flow/session.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | ||||
| package flow | ||||
|  | ||||
| import "github.com/golang-jwt/jwt/v5" | ||||
|  | ||||
| type SessionCookieClaims struct { | ||||
| 	jwt.Claims | ||||
|  | ||||
| 	SessionID     string `json:"sid"` | ||||
| 	Authenticated bool   `json:"authenticated"` | ||||
| } | ||||
|  | ||||
| func (fe *FlowExecutor) Session() *jwt.Token { | ||||
| 	sc := fe.SessionCookie() | ||||
| 	if sc == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 	t, _, _ := jwt.NewParser().ParseUnverified(sc.Value, &SessionCookieClaims{}) | ||||
| 	return t | ||||
| } | ||||
| @ -38,7 +38,14 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD | ||||
| 		username, err := instance.binder.GetUsername(bindDN) | ||||
| 		if err == nil { | ||||
| 			selectedApp = instance.GetAppSlug() | ||||
| 			return instance.binder.Bind(username, req) | ||||
| 			c, err := instance.binder.Bind(username, req) | ||||
| 			if c == ldap.LDAPResultSuccess { | ||||
| 				f := instance.GetFlags(req.BindDN) | ||||
| 				ls.connectionsSync.Lock() | ||||
| 				ls.connections[f.SessionID()] = conn | ||||
| 				ls.connectionsSync.Unlock() | ||||
| 			} | ||||
| 			return c, err | ||||
| 		} else { | ||||
| 			req.Log().WithError(err).Debug("Username not for instance") | ||||
| 		} | ||||
|  | ||||
| @ -27,8 +27,9 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul | ||||
|  | ||||
| 	passed, err := fe.Execute() | ||||
| 	flags := flags.UserFlags{ | ||||
| 		Session: fe.GetSession(), | ||||
| 		UserPk:  flags.InvalidUserPK, | ||||
| 		Session:    fe.SessionCookie(), | ||||
| 		SessionJWT: fe.Session(), | ||||
| 		UserPk:     flags.InvalidUserPK, | ||||
| 	} | ||||
| 	// only set flags if we don't have flags for this DN yet | ||||
| 	// as flags are only checked during the bind, we can remember whether a certain DN | ||||
|  | ||||
							
								
								
									
										20
									
								
								internal/outpost/ldap/close.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								internal/outpost/ldap/close.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | ||||
| package ldap | ||||
|  | ||||
| import "net" | ||||
|  | ||||
| func (ls *LDAPServer) Close(dn string, conn net.Conn) error { | ||||
| 	ls.connectionsSync.Lock() | ||||
| 	defer ls.connectionsSync.Unlock() | ||||
| 	key := "" | ||||
| 	for k, c := range ls.connections { | ||||
| 		if c == conn { | ||||
| 			key = k | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if key == "" { | ||||
| 		return nil | ||||
| 	} | ||||
| 	delete(ls.connections, key) | ||||
| 	return nil | ||||
| } | ||||
| @ -1,16 +1,30 @@ | ||||
| package flags | ||||
|  | ||||
| import ( | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/hex" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| 	"goauthentik.io/api/v3" | ||||
| 	"goauthentik.io/internal/outpost/flow" | ||||
| ) | ||||
|  | ||||
| const InvalidUserPK = -1 | ||||
|  | ||||
| type UserFlags struct { | ||||
| 	UserInfo  *api.User | ||||
| 	UserPk    int32 | ||||
| 	CanSearch bool | ||||
| 	Session   *http.Cookie | ||||
| 	UserInfo   *api.User | ||||
| 	UserPk     int32 | ||||
| 	CanSearch  bool | ||||
| 	Session    *http.Cookie | ||||
| 	SessionJWT *jwt.Token | ||||
| } | ||||
|  | ||||
| func (uf UserFlags) SessionID() string { | ||||
| 	if uf.SessionJWT == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	h := sha256.New() | ||||
| 	h.Write([]byte(uf.SessionJWT.Claims.(*flow.SessionCookieClaims).SessionID)) | ||||
| 	return hex.EncodeToString(h.Sum(nil)) | ||||
| } | ||||
|  | ||||
| @ -18,21 +18,26 @@ import ( | ||||
| ) | ||||
|  | ||||
| type LDAPServer struct { | ||||
| 	s           *ldap.Server | ||||
| 	log         *log.Entry | ||||
| 	ac          *ak.APIController | ||||
| 	cs          *ak.CryptoStore | ||||
| 	defaultCert *tls.Certificate | ||||
| 	providers   []*ProviderInstance | ||||
| 	s               *ldap.Server | ||||
| 	log             *log.Entry | ||||
| 	ac              *ak.APIController | ||||
| 	cs              *ak.CryptoStore | ||||
| 	defaultCert     *tls.Certificate | ||||
| 	providers       []*ProviderInstance | ||||
| 	connections     map[string]net.Conn | ||||
| 	connectionsSync sync.Mutex | ||||
| } | ||||
|  | ||||
| func NewServer(ac *ak.APIController) ak.Outpost { | ||||
| 	ls := &LDAPServer{ | ||||
| 		log:       log.WithField("logger", "authentik.outpost.ldap"), | ||||
| 		ac:        ac, | ||||
| 		cs:        ak.NewCryptoStore(ac.Client.CryptoApi), | ||||
| 		providers: []*ProviderInstance{}, | ||||
| 		log:             log.WithField("logger", "authentik.outpost.ldap"), | ||||
| 		ac:              ac, | ||||
| 		cs:              ak.NewCryptoStore(ac.Client.CryptoApi), | ||||
| 		providers:       []*ProviderInstance{}, | ||||
| 		connections:     map[string]net.Conn{}, | ||||
| 		connectionsSync: sync.Mutex{}, | ||||
| 	} | ||||
| 	ac.AddEventHandler(ls.handleWSSessionEnd) | ||||
| 	s := ldap.NewServer() | ||||
| 	s.EnforceLDAP = true | ||||
|  | ||||
| @ -50,6 +55,7 @@ func NewServer(ac *ak.APIController) ak.Outpost { | ||||
| 	s.BindFunc("", ls) | ||||
| 	s.UnbindFunc("", ls) | ||||
| 	s.SearchFunc("", ls) | ||||
| 	s.CloseFunc("", ls) | ||||
| 	return ls | ||||
| } | ||||
|  | ||||
| @ -117,3 +123,23 @@ func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) { | ||||
| 		p.binder.TimerFlowCacheExpiry(ctx) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ls *LDAPServer) handleWSSessionEnd(ctx context.Context, msg ak.Event) error { | ||||
| 	if msg.Instruction != ak.EventKindSessionEnd { | ||||
| 		return nil | ||||
| 	} | ||||
| 	mmsg := ak.EventArgsSessionEnd{} | ||||
| 	err := msg.ArgsAs(&mmsg) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	ls.connectionsSync.Lock() | ||||
| 	defer ls.connectionsSync.Unlock() | ||||
| 	ls.log.Info("Disconnecting session due to session end event") | ||||
| 	conn, ok := ls.connections[mmsg.SessionID] | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
| 	delete(ls.connections, mmsg.SessionID) | ||||
| 	return conn.Close() | ||||
| } | ||||
|  | ||||
| @ -44,38 +44,40 @@ func (ds *DirectSearcher) SearchSubschema(req *search.Request) (ldap.ServerSearc | ||||
| 					{ | ||||
| 						Name: "attributeTypes", | ||||
| 						Values: []string{ | ||||
| 							"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )", | ||||
| 							"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 0.9.2342.19200300.100.1.1 NAME 'uid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 0.9.2342.19200300.100.1.3 NAME 'mail' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 0.9.2342.19200300.100.1.41 NAME 'mobile' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.102 NAME 'memberOf' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' NO-USER-MODIFICATION )", | ||||
| 							"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.131 NAME 'co' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.141 NAME 'department' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.1 NAME 'name' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )", | ||||
| 							"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.221 NAME 'sAMAccountName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.261 NAME 'division' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.750 NAME 'groupType' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", | ||||
| 							"( 1.2.840.113556.1.4.782 NAME 'objectCategory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' SINGLE-VALUE )", | ||||
| 							"( 1.3.6.1.1.1.1.0 NAME 'uidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", | ||||
| 							"( 1.3.6.1.1.1.1.1 NAME 'gidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", | ||||
| 							"( 1.3.6.1.1.1.1.12 NAME 'memberUid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.26' )", | ||||
| 							"( 2.5.18.1 NAME 'createTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.18.2 NAME 'modifyTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )", | ||||
| 							"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )", | ||||
| 							"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )", | ||||
| 							"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
| 							"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", | ||||
|  | ||||
| 							// Custom attributes | ||||
| 							// Temporarily use 1.3.6.1.4.1.26027.1.1 as a base | ||||
|  | ||||
| @ -66,7 +66,7 @@ func NewProxyServer(ac *ak.APIController) ak.Outpost { | ||||
| 	globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) | ||||
| 	globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) | ||||
| 	rootMux.PathPrefix("/").HandlerFunc(s.Handle) | ||||
| 	ac.AddWSHandler(s.handleWSMessage) | ||||
| 	ac.AddEventHandler(s.handleWSMessage) | ||||
| 	return s | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -3,48 +3,27 @@ package proxyv2 | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	"goauthentik.io/internal/outpost/ak" | ||||
| 	"goauthentik.io/internal/outpost/proxyv2/application" | ||||
| ) | ||||
|  | ||||
| type WSProviderSubType string | ||||
|  | ||||
| const ( | ||||
| 	WSProviderSubTypeLogout WSProviderSubType = "logout" | ||||
| ) | ||||
|  | ||||
| type WSProviderMsg struct { | ||||
| 	SubType   WSProviderSubType `mapstructure:"sub_type"` | ||||
| 	SessionID string            `mapstructure:"session_id"` | ||||
| } | ||||
|  | ||||
| func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) { | ||||
| 	msg := &WSProviderMsg{} | ||||
| 	err := mapstructure.Decode(args, &msg) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| func (ps *ProxyServer) handleWSMessage(ctx context.Context, msg ak.Event) error { | ||||
| 	if msg.Instruction != ak.EventKindSessionEnd { | ||||
| 		return nil | ||||
| 	} | ||||
| 	return msg, nil | ||||
| } | ||||
|  | ||||
| func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) { | ||||
| 	msg, err := ParseWSProvider(args) | ||||
| 	mmsg := ak.EventArgsSessionEnd{} | ||||
| 	err := msg.ArgsAs(&mmsg) | ||||
| 	if err != nil { | ||||
| 		ps.log.WithError(err).Warning("invalid provider-specific ws message") | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 	switch msg.SubType { | ||||
| 	case WSProviderSubTypeLogout: | ||||
| 		for _, p := range ps.apps { | ||||
| 			ps.log.WithField("provider", p.Host).Debug("Logging out") | ||||
| 			err := p.Logout(ctx, func(c application.Claims) bool { | ||||
| 				return c.Sid == msg.SessionID | ||||
| 			}) | ||||
| 			if err != nil { | ||||
| 				ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") | ||||
| 			} | ||||
| 	for _, p := range ps.apps { | ||||
| 		ps.log.WithField("provider", p.Host).Debug("Logging out") | ||||
| 		err := p.Logout(ctx, func(c application.Claims) bool { | ||||
| 			return c.Sid == mmsg.SessionID | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout") | ||||
| 		} | ||||
| 	default: | ||||
| 		ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @ -6,7 +6,6 @@ import ( | ||||
| 	"strconv" | ||||
| 	"sync" | ||||
|  | ||||
| 	"github.com/mitchellh/mapstructure" | ||||
| 	log "github.com/sirupsen/logrus" | ||||
| 	"github.com/wwt/guac" | ||||
|  | ||||
| @ -30,7 +29,7 @@ func NewServer(ac *ak.APIController) ak.Outpost { | ||||
| 		connm: sync.RWMutex{}, | ||||
| 		conns: map[string]connection.Connection{}, | ||||
| 	} | ||||
| 	ac.AddWSHandler(rs.wsHandler) | ||||
| 	ac.AddEventHandler(rs.wsHandler) | ||||
| 	return rs | ||||
| } | ||||
|  | ||||
| @ -52,12 +51,14 @@ func parseIntOrZero(input string) int { | ||||
| 	return x | ||||
| } | ||||
|  | ||||
| func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) { | ||||
| func (rs *RACServer) wsHandler(ctx context.Context, msg ak.Event) error { | ||||
| 	if msg.Instruction != ak.EventKindProviderSpecific { | ||||
| 		return nil | ||||
| 	} | ||||
| 	wsm := WSMessage{} | ||||
| 	err := mapstructure.Decode(args, &wsm) | ||||
| 	err := msg.ArgsAs(&wsm) | ||||
| 	if err != nil { | ||||
| 		rs.log.WithError(err).Warning("invalid ws message") | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 	config := guac.NewGuacamoleConfiguration() | ||||
| 	config.Protocol = wsm.Protocol | ||||
| @ -71,23 +72,23 @@ func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) | ||||
| 	} | ||||
| 	cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config) | ||||
| 	if err != nil { | ||||
| 		rs.log.WithError(err).Warning("failed to setup connection") | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 	cc.OnError = func(err error) { | ||||
| 		rs.connm.Lock() | ||||
| 		delete(rs.conns, wsm.ConnID) | ||||
| 		_ = rs.ac.SendWSHello(map[string]interface{}{ | ||||
| 		_ = rs.ac.SendEventHello(map[string]interface{}{ | ||||
| 			"active_connections": len(rs.conns), | ||||
| 		}) | ||||
| 		rs.connm.Unlock() | ||||
| 	} | ||||
| 	rs.connm.Lock() | ||||
| 	rs.conns[wsm.ConnID] = *cc | ||||
| 	_ = rs.ac.SendWSHello(map[string]interface{}{ | ||||
| 	_ = rs.ac.SendEventHello(map[string]interface{}{ | ||||
| 		"active_connections": len(rs.conns), | ||||
| 	}) | ||||
| 	rs.connm.Unlock() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (rs *RACServer) Start() error { | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	 Jens L.
					Jens L.