outposts: Refactor session end signal and add LDAP support (#14539)

* outpost: promote session end signal to non-provider specific

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* implement server-side logout in ldap

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix previous import

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use better retry logic

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* log

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make more generic if we switch from ws to something else

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make it possible to e2e test WS

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix ldap session id

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ok I actually need to go to bed this took me an hour to fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format; add ldap test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix leftover state

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove thread

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use ws base for radius

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* separate test utils

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rename

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing super calls

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* websocket tests with browser 🎉

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add proxy test for sign out

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix install_id issue with channels tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix proxy basic auth test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* big code dedupe

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* allow passing go build args

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve waiting for outpost

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rewrite ldap tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ok actually fix the tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* undo a couple things that need more time to cook

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove unused lockfile-lint dependency since we use a shell script and SFE does not have a lockfile

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix session id for ldap

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing createTimestamp and modifyTimestamp ldap attributes

closes #10474

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2025-06-10 12:11:21 +02:00
committed by GitHub
parent 8bfc9ab7c9
commit 88fa7e37dc
25 changed files with 347 additions and 249 deletions

View File

@ -11,3 +11,4 @@ blueprints/local
!gen-ts-api/node_modules
!gen-ts-api/dist/**
!gen-go-api/
.venv

View File

@ -37,6 +37,9 @@ class WebsocketMessageInstruction(IntEnum):
# Provider specific message
PROVIDER_SPECIFIC = 3
# Session ended
SESSION_END = 4
@dataclass(slots=True)
class WebsocketMessage:
@ -145,6 +148,14 @@ class OutpostConsumer(JsonWebsocketConsumer):
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
)
def event_session_end(self, event):
"""Event handler which is called when a session is ended"""
self.send_json(
asdict(
WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
)
)
def event_provider_specific(self, event):
"""Event handler which can be called by provider-specific
implementations to send specific messages to the outpost"""

View File

@ -1,17 +1,24 @@
"""authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import Provider
from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save
from authentik.outposts.tasks import (
CACHE_KEY_OUTPOST_DOWN,
outpost_controller,
outpost_post_save,
outpost_session_end,
)
LOGGER = get_logger()
UPDATE_TRIGGERING_MODELS = (
@ -73,3 +80,17 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
instance.user.delete()
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
outpost_session_end.delay(instance.session.session_key)

View File

@ -1,5 +1,6 @@
"""outpost tasks"""
from hashlib import sha256
from os import R_OK, access
from pathlib import Path
from socket import gethostname
@ -49,6 +50,11 @@ LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
def hash_session_key(session_key: str) -> str:
"""Hash the session key for sending session end signals"""
return sha256(session_key.encode("ascii")).hexdigest()
def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
"""Get a controller for the outpost, when a service connection is defined"""
if not outpost.service_connection:
@ -289,3 +295,20 @@ def outpost_connection_discovery(self: SystemTask):
url=unix_socket_path,
)
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@CELERY_APP.task()
def outpost_session_end(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.all():
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.session.end",
"session_id": hashed_session_id,
},
)

View File

@ -1,23 +0,0 @@
"""Proxy provider signals"""
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.proxy.tasks import proxy_on_logout
@receiver(user_logged_out)
def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to proxy providers"""
if not request.session or not request.session.session_key:
return
proxy_on_logout.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
proxy_on_logout.delay(instance.session.session_key)

View File

@ -1,26 +0,0 @@
"""proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
from authentik.root.celery import CELERY_APP
@CELERY_APP.task()
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": hashed_session_id,
},
)

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.24.0
require (
beryju.io/ldap v0.1.0
github.com/avast/retry-go/v4 v4.6.1
github.com/coreos/go-oidc/v3 v3.14.1
github.com/getsentry/sentry-go v0.33.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1

2
go.sum
View File

@ -41,6 +41,8 @@ github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7V
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk=
github.com/avast/retry-go/v4 v4.6.1/go.mod h1:V6oF8njAwxJ5gRo1Q7Cxab24xs5NCWZBeaHHBklR8mA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=

View File

@ -13,6 +13,7 @@ import (
"syscall"
"time"
"github.com/avast/retry-go/v4"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/gorilla/websocket"
@ -25,8 +26,6 @@ import (
"goauthentik.io/internal/utils/web"
)
type WSHandler func(ctx context.Context, args map[string]interface{})
const ConfigLogLevel = "log_level"
// APIController main controller which connects to the authentik api via http and ws
@ -43,12 +42,11 @@ type APIController struct {
reloadOffset time.Duration
wsConn *websocket.Conn
lastWsReconnect time.Time
wsIsReconnecting bool
wsBackoffMultiplier int
wsHandlers []WSHandler
refreshHandlers []func()
eventConn *websocket.Conn
lastWsReconnect time.Time
wsIsReconnecting bool
eventHandlers []EventHandler
refreshHandlers []func()
instanceUUID uuid.UUID
}
@ -83,20 +81,19 @@ func NewAPIController(akURL url.URL, token string) *APIController {
// Because we don't know the outpost UUID, we simply do a list and pick the first
// The service account this token belongs to should only have access to a single outpost
var outposts *api.PaginatedOutpostList
var err error
for {
outposts, _, err = apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute()
if err == nil {
break
}
log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds")
time.Sleep(time.Second * 3)
}
outposts, _ := retry.DoWithData[*api.PaginatedOutpostList](
func() (*api.PaginatedOutpostList, error) {
outposts, _, err := apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute()
return outposts, err
},
retry.Attempts(0),
retry.Delay(time.Second*3),
retry.OnRetry(func(attempt uint, err error) {
log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds")
}),
)
if len(outposts.Results) < 1 {
panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
log.Panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
}
outpost := outposts.Results[0]
@ -119,17 +116,16 @@ func NewAPIController(akURL url.URL, token string) *APIController {
token: token,
logger: log,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
wsHandlers: []WSHandler{},
wsBackoffMultiplier: 1,
refreshHandlers: make([]func(), 0),
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
eventHandlers: []EventHandler{},
refreshHandlers: make([]func(), 0),
}
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
err = ac.initWS(akURL, outpost.Pk)
err = ac.initEvent(akURL, outpost.Pk)
if err != nil {
go ac.reconnectWS()
go ac.recentEvents()
}
ac.configureRefreshSignal()
return ac
@ -200,7 +196,7 @@ func (a *APIController) OnRefresh() error {
return err
}
func (a *APIController) getWebsocketPingArgs() map[string]interface{} {
func (a *APIController) getEventPingArgs() map[string]interface{} {
args := map[string]interface{}{
"version": constants.VERSION,
"buildHash": constants.BUILD(""),
@ -226,12 +222,12 @@ func (a *APIController) StartBackgroundTasks() error {
"build": constants.BUILD(""),
}).Set(1)
go func() {
a.logger.Debug("Starting WS Handler...")
a.startWSHandler()
a.logger.Debug("Starting Event Handler...")
a.startEventHandler()
}()
go func() {
a.logger.Debug("Starting WS Health notifier...")
a.startWSHealth()
a.logger.Debug("Starting Event health notifier...")
a.startEventHealth()
}()
go func() {
a.logger.Debug("Starting Interval updater...")

View File

@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/avast/retry-go/v4"
"github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/internal/config"
@ -30,7 +31,7 @@ func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, quer
return wsUrl
}
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
func (ac *APIController) initEvent(akURL url.URL, outpostUUID string) error {
query := akURL.Query()
query.Set("instance_uuid", ac.instanceUUID.String())
@ -57,19 +58,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
return err
}
ac.wsConn = ws
ac.eventConn = ws
// Send hello message with our version
msg := websocketMessage{
Instruction: WebsocketInstructionHello,
Args: ac.getWebsocketPingArgs(),
msg := Event{
Instruction: EventKindHello,
Args: ac.getEventPingArgs(),
}
err = ws.WriteJSON(msg)
if err != nil {
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik")
ac.logger.WithField("logger", "authentik.outpost.events").WithError(err).Warning("Failed to hello to authentik")
return err
}
ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
ac.logger.WithField("logger", "authentik.outpost.events").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
return nil
}
@ -77,19 +78,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
func (ac *APIController) Shutdown() {
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
err := ac.eventConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
ac.logger.WithError(err).Warning("failed to write close message")
return
}
err = ac.wsConn.Close()
err = ac.eventConn.Close()
if err != nil {
ac.logger.WithError(err).Warning("failed to close websocket")
}
ac.logger.Info("finished shutdown")
}
func (ac *APIController) reconnectWS() {
func (ac *APIController) recentEvents() {
if ac.wsIsReconnecting {
return
}
@ -100,46 +101,47 @@ func (ac *APIController) reconnectWS() {
Path: strings.ReplaceAll(ac.Client.GetConfig().Servers[0].URL, "api/v3", ""),
}
attempt := 1
for {
q := u.Query()
q.Set("attempt", strconv.Itoa(attempt))
u.RawQuery = q.Encode()
err := ac.initWS(u, ac.Outpost.Pk)
attempt += 1
if err != nil {
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier)
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second)
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
// Limit to 300 seconds (5m)
if ac.wsBackoffMultiplier >= 300 {
ac.wsBackoffMultiplier = 300
_ = retry.Do(
func() error {
q := u.Query()
q.Set("attempt", strconv.Itoa(attempt))
u.RawQuery = q.Encode()
err := ac.initEvent(u, ac.Outpost.Pk)
attempt += 1
if err != nil {
return err
}
} else {
ac.wsIsReconnecting = false
ac.wsBackoffMultiplier = 1
return
}
}
return nil
},
retry.Delay(1*time.Second),
retry.MaxDelay(5*time.Minute),
retry.DelayType(retry.BackOffDelay),
retry.Attempts(0),
retry.OnRetry(func(attempt uint, err error) {
ac.logger.Infof("waiting %d seconds to reconnect", attempt)
}),
)
}
func (ac *APIController) startWSHandler() {
logger := ac.logger.WithField("loop", "ws-handler")
func (ac *APIController) startEventHandler() {
logger := ac.logger.WithField("loop", "event-handler")
for {
var wsMsg websocketMessage
if ac.wsConn == nil {
go ac.reconnectWS()
var wsMsg Event
if ac.eventConn == nil {
go ac.recentEvents()
time.Sleep(time.Second * 5)
continue
}
err := ac.wsConn.ReadJSON(&wsMsg)
err := ac.eventConn.ReadJSON(&wsMsg)
if err != nil {
ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(),
"uuid": ac.instanceUUID.String(),
}).Set(0)
logger.WithError(err).Warning("ws read error")
go ac.reconnectWS()
logger.WithError(err).Warning("event read error")
go ac.recentEvents()
time.Sleep(time.Second * 5)
continue
}
@ -149,7 +151,8 @@ func (ac *APIController) startWSHandler() {
"uuid": ac.instanceUUID.String(),
}).Set(1)
switch wsMsg.Instruction {
case WebsocketInstructionTriggerUpdate:
case EventKindAck:
case EventKindTriggerUpdate:
time.Sleep(ac.reloadOffset)
logger.Debug("Got update trigger...")
err := ac.OnRefresh()
@ -164,30 +167,33 @@ func (ac *APIController) startWSHandler() {
"build": constants.BUILD(""),
}).SetToCurrentTime()
}
case WebsocketInstructionProviderSpecific:
for _, h := range ac.wsHandlers {
h(context.Background(), wsMsg.Args)
default:
for _, h := range ac.eventHandlers {
err := h(context.Background(), wsMsg)
if err != nil {
ac.logger.WithError(err).Warning("failed to run event handler")
}
}
}
}
}
func (ac *APIController) startWSHealth() {
func (ac *APIController) startEventHealth() {
ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C {
if ac.wsConn == nil {
go ac.reconnectWS()
if ac.eventConn == nil {
go ac.recentEvents()
time.Sleep(time.Second * 5)
continue
}
err := ac.SendWSHello(map[string]interface{}{})
err := ac.SendEventHello(map[string]interface{}{})
if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
go ac.reconnectWS()
ac.logger.WithField("loop", "event-health").WithError(err).Warning("event write error")
go ac.recentEvents()
time.Sleep(time.Second * 5)
continue
} else {
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
ac.logger.WithField("loop", "event-health").Trace("hello'd")
ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(),
@ -230,19 +236,19 @@ func (ac *APIController) startIntervalUpdater() {
}
}
func (a *APIController) AddWSHandler(handler WSHandler) {
a.wsHandlers = append(a.wsHandlers, handler)
func (a *APIController) AddEventHandler(handler EventHandler) {
a.eventHandlers = append(a.eventHandlers, handler)
}
func (a *APIController) SendWSHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs()
func (a *APIController) SendEventHello(args map[string]interface{}) error {
allArgs := a.getEventPingArgs()
for key, value := range args {
allArgs[key] = value
}
aliveMsg := websocketMessage{
Instruction: WebsocketInstructionHello,
aliveMsg := Event{
Instruction: EventKindHello,
Args: allArgs,
}
err := a.wsConn.WriteJSON(aliveMsg)
err := a.eventConn.WriteJSON(aliveMsg)
return err
}

View File

@ -0,0 +1,37 @@
package ak
import (
"context"
"github.com/mitchellh/mapstructure"
)
type EventKind int
const (
// Code used to acknowledge a previous message
EventKindAck EventKind = 0
// Code used to send a healthcheck keepalive
EventKindHello EventKind = 1
// Code received to trigger a config update
EventKindTriggerUpdate EventKind = 2
// Code received to trigger some provider specific function
EventKindProviderSpecific EventKind = 3
// Code received to identify the end of a session
EventKindSessionEnd EventKind = 4
)
type EventHandler func(ctx context.Context, msg Event) error
type Event struct {
Instruction EventKind `json:"instruction"`
Args interface{} `json:"args"`
}
func (wm Event) ArgsAs(out interface{}) error {
return mapstructure.Decode(wm.Args, out)
}
type EventArgsSessionEnd struct {
SessionID string `mapstructure:"session_id"`
}

View File

@ -15,7 +15,7 @@ func URLMustParse(u string) *url.URL {
return ur
}
func TestWebsocketURL(t *testing.T) {
func TestEventWebsocketURL(t *testing.T) {
u := URLMustParse("http://localhost:9000?foo=bar")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{}
@ -23,7 +23,7 @@ func TestWebsocketURL(t *testing.T) {
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String())
}
func TestWebsocketURL_Query(t *testing.T) {
func TestEventWebsocketURL_Query(t *testing.T) {
u := URLMustParse("http://localhost:9000?foo=bar")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{}
@ -33,7 +33,7 @@ func TestWebsocketURL_Query(t *testing.T) {
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String())
}
func TestWebsocketURL_Subpath(t *testing.T) {
func TestEventWebsocketURL_Subpath(t *testing.T) {
u := URLMustParse("http://localhost:9000/foo/bar/")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{}

View File

@ -1,19 +0,0 @@
package ak
type websocketInstruction int
const (
// WebsocketInstructionAck Code used to acknowledge a previous message
WebsocketInstructionAck websocketInstruction = 0
// WebsocketInstructionHello Code used to send a healthcheck keepalive
WebsocketInstructionHello websocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
WebsocketInstructionProviderSpecific websocketInstruction = 3
)
type websocketMessage struct {
Instruction websocketInstruction `json:"instruction"`
Args map[string]interface{} `json:"args"`
}

View File

@ -55,11 +55,10 @@ func MockAK(outpost api.Outpost, globalConfig api.Config) *APIController {
token: token,
logger: log,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
wsBackoffMultiplier: 1,
refreshHandlers: make([]func(), 0),
reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(),
Outpost: outpost,
refreshHandlers: make([]func(), 0),
}
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
return ac

View File

@ -127,7 +127,7 @@ func (fe *FlowExecutor) getAnswer(stage StageComponent) string {
return ""
}
func (fe *FlowExecutor) GetSession() *http.Cookie {
func (fe *FlowExecutor) SessionCookie() *http.Cookie {
return fe.session
}

View File

@ -0,0 +1,19 @@
package flow
import "github.com/golang-jwt/jwt/v5"
type SessionCookieClaims struct {
jwt.Claims
SessionID string `json:"sid"`
Authenticated bool `json:"authenticated"`
}
func (fe *FlowExecutor) Session() *jwt.Token {
sc := fe.SessionCookie()
if sc == nil {
return nil
}
t, _, _ := jwt.NewParser().ParseUnverified(sc.Value, &SessionCookieClaims{})
return t
}

View File

@ -38,7 +38,14 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
username, err := instance.binder.GetUsername(bindDN)
if err == nil {
selectedApp = instance.GetAppSlug()
return instance.binder.Bind(username, req)
c, err := instance.binder.Bind(username, req)
if c == ldap.LDAPResultSuccess {
f := instance.GetFlags(req.BindDN)
ls.connectionsSync.Lock()
ls.connections[f.SessionID()] = conn
ls.connectionsSync.Unlock()
}
return c, err
} else {
req.Log().WithError(err).Debug("Username not for instance")
}

View File

@ -27,8 +27,9 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
passed, err := fe.Execute()
flags := flags.UserFlags{
Session: fe.GetSession(),
UserPk: flags.InvalidUserPK,
Session: fe.SessionCookie(),
SessionJWT: fe.Session(),
UserPk: flags.InvalidUserPK,
}
// only set flags if we don't have flags for this DN yet
// as flags are only checked during the bind, we can remember whether a certain DN

View File

@ -0,0 +1,20 @@
package ldap
import "net"
func (ls *LDAPServer) Close(dn string, conn net.Conn) error {
ls.connectionsSync.Lock()
defer ls.connectionsSync.Unlock()
key := ""
for k, c := range ls.connections {
if c == conn {
key = k
break
}
}
if key == "" {
return nil
}
delete(ls.connections, key)
return nil
}

View File

@ -1,16 +1,30 @@
package flags
import (
"crypto/sha256"
"encoding/hex"
"net/http"
"github.com/golang-jwt/jwt/v5"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/flow"
)
const InvalidUserPK = -1
type UserFlags struct {
UserInfo *api.User
UserPk int32
CanSearch bool
Session *http.Cookie
UserInfo *api.User
UserPk int32
CanSearch bool
Session *http.Cookie
SessionJWT *jwt.Token
}
func (uf UserFlags) SessionID() string {
if uf.SessionJWT == nil {
return ""
}
h := sha256.New()
h.Write([]byte(uf.SessionJWT.Claims.(*flow.SessionCookieClaims).SessionID))
return hex.EncodeToString(h.Sum(nil))
}

View File

@ -18,21 +18,26 @@ import (
)
type LDAPServer struct {
s *ldap.Server
log *log.Entry
ac *ak.APIController
cs *ak.CryptoStore
defaultCert *tls.Certificate
providers []*ProviderInstance
s *ldap.Server
log *log.Entry
ac *ak.APIController
cs *ak.CryptoStore
defaultCert *tls.Certificate
providers []*ProviderInstance
connections map[string]net.Conn
connectionsSync sync.Mutex
}
func NewServer(ac *ak.APIController) ak.Outpost {
ls := &LDAPServer{
log: log.WithField("logger", "authentik.outpost.ldap"),
ac: ac,
cs: ak.NewCryptoStore(ac.Client.CryptoApi),
providers: []*ProviderInstance{},
log: log.WithField("logger", "authentik.outpost.ldap"),
ac: ac,
cs: ak.NewCryptoStore(ac.Client.CryptoApi),
providers: []*ProviderInstance{},
connections: map[string]net.Conn{},
connectionsSync: sync.Mutex{},
}
ac.AddEventHandler(ls.handleWSSessionEnd)
s := ldap.NewServer()
s.EnforceLDAP = true
@ -50,6 +55,7 @@ func NewServer(ac *ak.APIController) ak.Outpost {
s.BindFunc("", ls)
s.UnbindFunc("", ls)
s.SearchFunc("", ls)
s.CloseFunc("", ls)
return ls
}
@ -117,3 +123,23 @@ func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) {
p.binder.TimerFlowCacheExpiry(ctx)
}
}
func (ls *LDAPServer) handleWSSessionEnd(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindSessionEnd {
return nil
}
mmsg := ak.EventArgsSessionEnd{}
err := msg.ArgsAs(&mmsg)
if err != nil {
return err
}
ls.connectionsSync.Lock()
defer ls.connectionsSync.Unlock()
ls.log.Info("Disconnecting session due to session end event")
conn, ok := ls.connections[mmsg.SessionID]
if !ok {
return nil
}
delete(ls.connections, mmsg.SessionID)
return conn.Close()
}

View File

@ -44,38 +44,40 @@ func (ds *DirectSearcher) SearchSubschema(req *search.Request) (ldap.ServerSearc
{
Name: "attributeTypes",
Values: []string{
"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )",
"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )",
"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 0.9.2342.19200300.100.1.1 NAME 'uid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 0.9.2342.19200300.100.1.3 NAME 'mail' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 0.9.2342.19200300.100.1.41 NAME 'mobile' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.102 NAME 'memberOf' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' NO-USER-MODIFICATION )",
"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.131 NAME 'co' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.141 NAME 'department' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.1 NAME 'name' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )",
"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.221 NAME 'sAMAccountName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.261 NAME 'division' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.750 NAME 'groupType' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.782 NAME 'objectCategory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.0 NAME 'uidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.1 NAME 'gidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.12 NAME 'memberUid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.26' )",
"( 2.5.18.1 NAME 'createTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )",
"( 2.5.18.2 NAME 'modifyTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )",
"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )",
"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )",
"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
// Custom attributes
// Temporarily use 1.3.6.1.4.1.26027.1.1 as a base

View File

@ -66,7 +66,7 @@ func NewProxyServer(ac *ak.APIController) ak.Outpost {
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
rootMux.PathPrefix("/").HandlerFunc(s.Handle)
ac.AddWSHandler(s.handleWSMessage)
ac.AddEventHandler(s.handleWSMessage)
return s
}

View File

@ -3,48 +3,27 @@ package proxyv2
import (
"context"
"github.com/mitchellh/mapstructure"
"goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/proxyv2/application"
)
type WSProviderSubType string
const (
WSProviderSubTypeLogout WSProviderSubType = "logout"
)
type WSProviderMsg struct {
SubType WSProviderSubType `mapstructure:"sub_type"`
SessionID string `mapstructure:"session_id"`
}
func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
msg := &WSProviderMsg{}
err := mapstructure.Decode(args, &msg)
if err != nil {
return nil, err
func (ps *ProxyServer) handleWSMessage(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindSessionEnd {
return nil
}
return msg, nil
}
func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
msg, err := ParseWSProvider(args)
mmsg := ak.EventArgsSessionEnd{}
err := msg.ArgsAs(&mmsg)
if err != nil {
ps.log.WithError(err).Warning("invalid provider-specific ws message")
return
return err
}
switch msg.SubType {
case WSProviderSubTypeLogout:
for _, p := range ps.apps {
ps.log.WithField("provider", p.Host).Debug("Logging out")
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == msg.SessionID
})
if err != nil {
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
}
for _, p := range ps.apps {
ps.log.WithField("provider", p.Host).Debug("Logging out")
err := p.Logout(ctx, func(c application.Claims) bool {
return c.Sid == mmsg.SessionID
})
if err != nil {
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
}
default:
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
}
return nil
}

View File

@ -6,7 +6,6 @@ import (
"strconv"
"sync"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
"github.com/wwt/guac"
@ -30,7 +29,7 @@ func NewServer(ac *ak.APIController) ak.Outpost {
connm: sync.RWMutex{},
conns: map[string]connection.Connection{},
}
ac.AddWSHandler(rs.wsHandler)
ac.AddEventHandler(rs.wsHandler)
return rs
}
@ -52,12 +51,14 @@ func parseIntOrZero(input string) int {
return x
}
func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) {
func (rs *RACServer) wsHandler(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindProviderSpecific {
return nil
}
wsm := WSMessage{}
err := mapstructure.Decode(args, &wsm)
err := msg.ArgsAs(&wsm)
if err != nil {
rs.log.WithError(err).Warning("invalid ws message")
return
return err
}
config := guac.NewGuacamoleConfiguration()
config.Protocol = wsm.Protocol
@ -71,23 +72,23 @@ func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{})
}
cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config)
if err != nil {
rs.log.WithError(err).Warning("failed to setup connection")
return
return err
}
cc.OnError = func(err error) {
rs.connm.Lock()
delete(rs.conns, wsm.ConnID)
_ = rs.ac.SendWSHello(map[string]interface{}{
_ = rs.ac.SendEventHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
}
rs.connm.Lock()
rs.conns[wsm.ConnID] = *cc
_ = rs.ac.SendWSHello(map[string]interface{}{
_ = rs.ac.SendEventHello(map[string]interface{}{
"active_connections": len(rs.conns),
})
rs.connm.Unlock()
return nil
}
func (rs *RACServer) Start() error {