outposts: Refactor session end signal and add LDAP support (#14539)

* outpost: promote session end signal to non-provider specific

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* implement server-side logout in ldap

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix previous import

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use better retry logic

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* log

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make more generic if we switch from ws to something else

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* make it possible to e2e test WS

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix ldap session id

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ok I actually need to go to bed this took me an hour to fix

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format; add ldap test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix leftover state

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove thread

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* use ws base for radius

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* separate test utils

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rename

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing super calls

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* websocket tests with browser 🎉

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add proxy test for sign out

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix install_id issue with channels tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix proxy basic auth test

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* big code dedupe

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* allow passing go build args

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* improve waiting for outpost

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rewrite ldap tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* ok actually fix the tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* undo a couple things that need more time to cook

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove unused lockfile-lint dependency since we use a shell script and SFE does not have a lockfile

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix session id for ldap

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix missing createTimestamp and modifyTimestamp ldap attributes

closes #10474

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2025-06-10 12:11:21 +02:00
committed by GitHub
parent 8bfc9ab7c9
commit 88fa7e37dc
25 changed files with 347 additions and 249 deletions

View File

@ -11,3 +11,4 @@ blueprints/local
!gen-ts-api/node_modules !gen-ts-api/node_modules
!gen-ts-api/dist/** !gen-ts-api/dist/**
!gen-go-api/ !gen-go-api/
.venv

View File

@ -37,6 +37,9 @@ class WebsocketMessageInstruction(IntEnum):
# Provider specific message # Provider specific message
PROVIDER_SPECIFIC = 3 PROVIDER_SPECIFIC = 3
# Session ended
SESSION_END = 4
@dataclass(slots=True) @dataclass(slots=True)
class WebsocketMessage: class WebsocketMessage:
@ -145,6 +148,14 @@ class OutpostConsumer(JsonWebsocketConsumer):
asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)) asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
) )
def event_session_end(self, event):
"""Event handler which is called when a session is ended"""
self.send_json(
asdict(
WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
)
)
def event_provider_specific(self, event): def event_provider_specific(self, event):
"""Event handler which can be called by provider-specific """Event handler which can be called by provider-specific
implementations to send specific messages to the outpost""" implementations to send specific messages to the outpost"""

View File

@ -1,17 +1,24 @@
"""authentik outpost signals""" """authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Model from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import Provider from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import CACHE_KEY_OUTPOST_DOWN, outpost_controller, outpost_post_save from authentik.outposts.tasks import (
CACHE_KEY_OUTPOST_DOWN,
outpost_controller,
outpost_post_save,
outpost_session_end,
)
LOGGER = get_logger() LOGGER = get_logger()
UPDATE_TRIGGERING_MODELS = ( UPDATE_TRIGGERING_MODELS = (
@ -73,3 +80,17 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
instance.user.delete() instance.user.delete()
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
outpost_session_end.delay(instance.session.session_key)

View File

@ -1,5 +1,6 @@
"""outpost tasks""" """outpost tasks"""
from hashlib import sha256
from os import R_OK, access from os import R_OK, access
from pathlib import Path from pathlib import Path
from socket import gethostname from socket import gethostname
@ -49,6 +50,11 @@ LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
def hash_session_key(session_key: str) -> str:
"""Hash the session key for sending session end signals"""
return sha256(session_key.encode("ascii")).hexdigest()
def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
"""Get a controller for the outpost, when a service connection is defined""" """Get a controller for the outpost, when a service connection is defined"""
if not outpost.service_connection: if not outpost.service_connection:
@ -289,3 +295,20 @@ def outpost_connection_discovery(self: SystemTask):
url=unix_socket_path, url=unix_socket_path,
) )
self.set_status(TaskStatus.SUCCESSFUL, *messages) self.set_status(TaskStatus.SUCCESSFUL, *messages)
@CELERY_APP.task()
def outpost_session_end(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.all():
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.session.end",
"session_id": hashed_session_id,
},
)

View File

@ -1,23 +0,0 @@
"""Proxy provider signals"""
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.proxy.tasks import proxy_on_logout
@receiver(user_logged_out)
def logout_proxy_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to proxy providers"""
if not request.session or not request.session.session_key:
return
proxy_on_logout.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
proxy_on_logout.delay(instance.session.session_key)

View File

@ -1,26 +0,0 @@
"""proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
from authentik.root.celery import CELERY_APP
@CELERY_APP.task()
def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": hashed_session_id,
},
)

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.24.0
require ( require (
beryju.io/ldap v0.1.0 beryju.io/ldap v0.1.0
github.com/avast/retry-go/v4 v4.6.1
github.com/coreos/go-oidc/v3 v3.14.1 github.com/coreos/go-oidc/v3 v3.14.1
github.com/getsentry/sentry-go v0.33.0 github.com/getsentry/sentry-go v0.33.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1

2
go.sum
View File

@ -41,6 +41,8 @@ github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7V
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3dyBCFEj5IhUbnKptjxatkF07cF2ak3yi77so=
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
github.com/avast/retry-go/v4 v4.6.1 h1:VkOLRubHdisGrHnTu89g08aQEWEgRU7LVEop3GbIcMk=
github.com/avast/retry-go/v4 v4.6.1/go.mod h1:V6oF8njAwxJ5gRo1Q7Cxab24xs5NCWZBeaHHBklR8mA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=

View File

@ -13,6 +13,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/avast/retry-go/v4"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -25,8 +26,6 @@ import (
"goauthentik.io/internal/utils/web" "goauthentik.io/internal/utils/web"
) )
type WSHandler func(ctx context.Context, args map[string]interface{})
const ConfigLogLevel = "log_level" const ConfigLogLevel = "log_level"
// APIController main controller which connects to the authentik api via http and ws // APIController main controller which connects to the authentik api via http and ws
@ -43,12 +42,11 @@ type APIController struct {
reloadOffset time.Duration reloadOffset time.Duration
wsConn *websocket.Conn eventConn *websocket.Conn
lastWsReconnect time.Time lastWsReconnect time.Time
wsIsReconnecting bool wsIsReconnecting bool
wsBackoffMultiplier int eventHandlers []EventHandler
wsHandlers []WSHandler refreshHandlers []func()
refreshHandlers []func()
instanceUUID uuid.UUID instanceUUID uuid.UUID
} }
@ -83,20 +81,19 @@ func NewAPIController(akURL url.URL, token string) *APIController {
// Because we don't know the outpost UUID, we simply do a list and pick the first // Because we don't know the outpost UUID, we simply do a list and pick the first
// The service account this token belongs to should only have access to a single outpost // The service account this token belongs to should only have access to a single outpost
var outposts *api.PaginatedOutpostList outposts, _ := retry.DoWithData[*api.PaginatedOutpostList](
var err error func() (*api.PaginatedOutpostList, error) {
for { outposts, _, err := apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute()
outposts, _, err = apiClient.OutpostsApi.OutpostsInstancesList(context.Background()).Execute() return outposts, err
},
if err == nil { retry.Attempts(0),
break retry.Delay(time.Second*3),
} retry.OnRetry(func(attempt uint, err error) {
log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds")
log.WithError(err).Error("Failed to fetch outpost configuration, retrying in 3 seconds") }),
time.Sleep(time.Second * 3) )
}
if len(outposts.Results) < 1 { if len(outposts.Results) < 1 {
panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost") log.Panic("No outposts found with given token, ensure the given token corresponds to an authenitk Outpost")
} }
outpost := outposts.Results[0] outpost := outposts.Results[0]
@ -119,17 +116,16 @@ func NewAPIController(akURL url.URL, token string) *APIController {
token: token, token: token,
logger: log, logger: log,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second, reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(), instanceUUID: uuid.New(),
Outpost: outpost, Outpost: outpost,
wsHandlers: []WSHandler{}, eventHandlers: []EventHandler{},
wsBackoffMultiplier: 1, refreshHandlers: make([]func(), 0),
refreshHandlers: make([]func(), 0),
} }
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
err = ac.initWS(akURL, outpost.Pk) err = ac.initEvent(akURL, outpost.Pk)
if err != nil { if err != nil {
go ac.reconnectWS() go ac.recentEvents()
} }
ac.configureRefreshSignal() ac.configureRefreshSignal()
return ac return ac
@ -200,7 +196,7 @@ func (a *APIController) OnRefresh() error {
return err return err
} }
func (a *APIController) getWebsocketPingArgs() map[string]interface{} { func (a *APIController) getEventPingArgs() map[string]interface{} {
args := map[string]interface{}{ args := map[string]interface{}{
"version": constants.VERSION, "version": constants.VERSION,
"buildHash": constants.BUILD(""), "buildHash": constants.BUILD(""),
@ -226,12 +222,12 @@ func (a *APIController) StartBackgroundTasks() error {
"build": constants.BUILD(""), "build": constants.BUILD(""),
}).Set(1) }).Set(1)
go func() { go func() {
a.logger.Debug("Starting WS Handler...") a.logger.Debug("Starting Event Handler...")
a.startWSHandler() a.startEventHandler()
}() }()
go func() { go func() {
a.logger.Debug("Starting WS Health notifier...") a.logger.Debug("Starting Event health notifier...")
a.startWSHealth() a.startEventHealth()
}() }()
go func() { go func() {
a.logger.Debug("Starting Interval updater...") a.logger.Debug("Starting Interval updater...")

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/avast/retry-go/v4"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"goauthentik.io/internal/config" "goauthentik.io/internal/config"
@ -30,7 +31,7 @@ func (ac *APIController) getWebsocketURL(akURL url.URL, outpostUUID string, quer
return wsUrl return wsUrl
} }
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error { func (ac *APIController) initEvent(akURL url.URL, outpostUUID string) error {
query := akURL.Query() query := akURL.Query()
query.Set("instance_uuid", ac.instanceUUID.String()) query.Set("instance_uuid", ac.instanceUUID.String())
@ -57,19 +58,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
return err return err
} }
ac.wsConn = ws ac.eventConn = ws
// Send hello message with our version // Send hello message with our version
msg := websocketMessage{ msg := Event{
Instruction: WebsocketInstructionHello, Instruction: EventKindHello,
Args: ac.getWebsocketPingArgs(), Args: ac.getEventPingArgs(),
} }
err = ws.WriteJSON(msg) err = ws.WriteJSON(msg)
if err != nil { if err != nil {
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithError(err).Warning("Failed to hello to authentik") ac.logger.WithField("logger", "authentik.outpost.events").WithError(err).Warning("Failed to hello to authentik")
return err return err
} }
ac.lastWsReconnect = time.Now() ac.lastWsReconnect = time.Now()
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket") ac.logger.WithField("logger", "authentik.outpost.events").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
return nil return nil
} }
@ -77,19 +78,19 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
func (ac *APIController) Shutdown() { func (ac *APIController) Shutdown() {
// Cleanly close the connection by sending a close message and then // Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection. // waiting (with timeout) for the server to close the connection.
err := ac.wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) err := ac.eventConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil { if err != nil {
ac.logger.WithError(err).Warning("failed to write close message") ac.logger.WithError(err).Warning("failed to write close message")
return return
} }
err = ac.wsConn.Close() err = ac.eventConn.Close()
if err != nil { if err != nil {
ac.logger.WithError(err).Warning("failed to close websocket") ac.logger.WithError(err).Warning("failed to close websocket")
} }
ac.logger.Info("finished shutdown") ac.logger.Info("finished shutdown")
} }
func (ac *APIController) reconnectWS() { func (ac *APIController) recentEvents() {
if ac.wsIsReconnecting { if ac.wsIsReconnecting {
return return
} }
@ -100,46 +101,47 @@ func (ac *APIController) reconnectWS() {
Path: strings.ReplaceAll(ac.Client.GetConfig().Servers[0].URL, "api/v3", ""), Path: strings.ReplaceAll(ac.Client.GetConfig().Servers[0].URL, "api/v3", ""),
} }
attempt := 1 attempt := 1
for { _ = retry.Do(
q := u.Query() func() error {
q.Set("attempt", strconv.Itoa(attempt)) q := u.Query()
u.RawQuery = q.Encode() q.Set("attempt", strconv.Itoa(attempt))
err := ac.initWS(u, ac.Outpost.Pk) u.RawQuery = q.Encode()
attempt += 1 err := ac.initEvent(u, ac.Outpost.Pk)
if err != nil { attempt += 1
ac.logger.Infof("waiting %d seconds to reconnect", ac.wsBackoffMultiplier) if err != nil {
time.Sleep(time.Duration(ac.wsBackoffMultiplier) * time.Second) return err
ac.wsBackoffMultiplier = ac.wsBackoffMultiplier * 2
// Limit to 300 seconds (5m)
if ac.wsBackoffMultiplier >= 300 {
ac.wsBackoffMultiplier = 300
} }
} else {
ac.wsIsReconnecting = false ac.wsIsReconnecting = false
ac.wsBackoffMultiplier = 1 return nil
return },
} retry.Delay(1*time.Second),
} retry.MaxDelay(5*time.Minute),
retry.DelayType(retry.BackOffDelay),
retry.Attempts(0),
retry.OnRetry(func(attempt uint, err error) {
ac.logger.Infof("waiting %d seconds to reconnect", attempt)
}),
)
} }
func (ac *APIController) startWSHandler() { func (ac *APIController) startEventHandler() {
logger := ac.logger.WithField("loop", "ws-handler") logger := ac.logger.WithField("loop", "event-handler")
for { for {
var wsMsg websocketMessage var wsMsg Event
if ac.wsConn == nil { if ac.eventConn == nil {
go ac.reconnectWS() go ac.recentEvents()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} }
err := ac.wsConn.ReadJSON(&wsMsg) err := ac.eventConn.ReadJSON(&wsMsg)
if err != nil { if err != nil {
ConnectionStatus.With(prometheus.Labels{ ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name, "outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(), "outpost_type": ac.Server.Type(),
"uuid": ac.instanceUUID.String(), "uuid": ac.instanceUUID.String(),
}).Set(0) }).Set(0)
logger.WithError(err).Warning("ws read error") logger.WithError(err).Warning("event read error")
go ac.reconnectWS() go ac.recentEvents()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} }
@ -149,7 +151,8 @@ func (ac *APIController) startWSHandler() {
"uuid": ac.instanceUUID.String(), "uuid": ac.instanceUUID.String(),
}).Set(1) }).Set(1)
switch wsMsg.Instruction { switch wsMsg.Instruction {
case WebsocketInstructionTriggerUpdate: case EventKindAck:
case EventKindTriggerUpdate:
time.Sleep(ac.reloadOffset) time.Sleep(ac.reloadOffset)
logger.Debug("Got update trigger...") logger.Debug("Got update trigger...")
err := ac.OnRefresh() err := ac.OnRefresh()
@ -164,30 +167,33 @@ func (ac *APIController) startWSHandler() {
"build": constants.BUILD(""), "build": constants.BUILD(""),
}).SetToCurrentTime() }).SetToCurrentTime()
} }
case WebsocketInstructionProviderSpecific: default:
for _, h := range ac.wsHandlers { for _, h := range ac.eventHandlers {
h(context.Background(), wsMsg.Args) err := h(context.Background(), wsMsg)
if err != nil {
ac.logger.WithError(err).Warning("failed to run event handler")
}
} }
} }
} }
} }
func (ac *APIController) startWSHealth() { func (ac *APIController) startEventHealth() {
ticker := time.NewTicker(time.Second * 10) ticker := time.NewTicker(time.Second * 10)
for ; true; <-ticker.C { for ; true; <-ticker.C {
if ac.wsConn == nil { if ac.eventConn == nil {
go ac.reconnectWS() go ac.recentEvents()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} }
err := ac.SendWSHello(map[string]interface{}{}) err := ac.SendEventHello(map[string]interface{}{})
if err != nil { if err != nil {
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error") ac.logger.WithField("loop", "event-health").WithError(err).Warning("event write error")
go ac.reconnectWS() go ac.recentEvents()
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
continue continue
} else { } else {
ac.logger.WithField("loop", "ws-health").Trace("hello'd") ac.logger.WithField("loop", "event-health").Trace("hello'd")
ConnectionStatus.With(prometheus.Labels{ ConnectionStatus.With(prometheus.Labels{
"outpost_name": ac.Outpost.Name, "outpost_name": ac.Outpost.Name,
"outpost_type": ac.Server.Type(), "outpost_type": ac.Server.Type(),
@ -230,19 +236,19 @@ func (ac *APIController) startIntervalUpdater() {
} }
} }
func (a *APIController) AddWSHandler(handler WSHandler) { func (a *APIController) AddEventHandler(handler EventHandler) {
a.wsHandlers = append(a.wsHandlers, handler) a.eventHandlers = append(a.eventHandlers, handler)
} }
func (a *APIController) SendWSHello(args map[string]interface{}) error { func (a *APIController) SendEventHello(args map[string]interface{}) error {
allArgs := a.getWebsocketPingArgs() allArgs := a.getEventPingArgs()
for key, value := range args { for key, value := range args {
allArgs[key] = value allArgs[key] = value
} }
aliveMsg := websocketMessage{ aliveMsg := Event{
Instruction: WebsocketInstructionHello, Instruction: EventKindHello,
Args: allArgs, Args: allArgs,
} }
err := a.wsConn.WriteJSON(aliveMsg) err := a.eventConn.WriteJSON(aliveMsg)
return err return err
} }

View File

@ -0,0 +1,37 @@
package ak
import (
"context"
"github.com/mitchellh/mapstructure"
)
type EventKind int
const (
// Code used to acknowledge a previous message
EventKindAck EventKind = 0
// Code used to send a healthcheck keepalive
EventKindHello EventKind = 1
// Code received to trigger a config update
EventKindTriggerUpdate EventKind = 2
// Code received to trigger some provider specific function
EventKindProviderSpecific EventKind = 3
// Code received to identify the end of a session
EventKindSessionEnd EventKind = 4
)
type EventHandler func(ctx context.Context, msg Event) error
type Event struct {
Instruction EventKind `json:"instruction"`
Args interface{} `json:"args"`
}
func (wm Event) ArgsAs(out interface{}) error {
return mapstructure.Decode(wm.Args, out)
}
type EventArgsSessionEnd struct {
SessionID string `mapstructure:"session_id"`
}

View File

@ -15,7 +15,7 @@ func URLMustParse(u string) *url.URL {
return ur return ur
} }
func TestWebsocketURL(t *testing.T) { func TestEventWebsocketURL(t *testing.T) {
u := URLMustParse("http://localhost:9000?foo=bar") u := URLMustParse("http://localhost:9000?foo=bar")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{} ac := &APIController{}
@ -23,7 +23,7 @@ func TestWebsocketURL(t *testing.T) {
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String()) assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?foo=bar", nu.String())
} }
func TestWebsocketURL_Query(t *testing.T) { func TestEventWebsocketURL_Query(t *testing.T) {
u := URLMustParse("http://localhost:9000?foo=bar") u := URLMustParse("http://localhost:9000?foo=bar")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{} ac := &APIController{}
@ -33,7 +33,7 @@ func TestWebsocketURL_Query(t *testing.T) {
assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String()) assert.Equal(t, "ws://localhost:9000/ws/outpost/23470845-7263-4fe3-bd79-ec1d7bf77d77/?bar=baz&foo=bar", nu.String())
} }
func TestWebsocketURL_Subpath(t *testing.T) { func TestEventWebsocketURL_Subpath(t *testing.T) {
u := URLMustParse("http://localhost:9000/foo/bar/") u := URLMustParse("http://localhost:9000/foo/bar/")
uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77" uuid := "23470845-7263-4fe3-bd79-ec1d7bf77d77"
ac := &APIController{} ac := &APIController{}

View File

@ -1,19 +0,0 @@
package ak
type websocketInstruction int
const (
// WebsocketInstructionAck Code used to acknowledge a previous message
WebsocketInstructionAck websocketInstruction = 0
// WebsocketInstructionHello Code used to send a healthcheck keepalive
WebsocketInstructionHello websocketInstruction = 1
// WebsocketInstructionTriggerUpdate Code received to trigger a config update
WebsocketInstructionTriggerUpdate websocketInstruction = 2
// WebsocketInstructionProviderSpecific Code received to trigger some provider specific function
WebsocketInstructionProviderSpecific websocketInstruction = 3
)
type websocketMessage struct {
Instruction websocketInstruction `json:"instruction"`
Args map[string]interface{} `json:"args"`
}

View File

@ -55,11 +55,10 @@ func MockAK(outpost api.Outpost, globalConfig api.Config) *APIController {
token: token, token: token,
logger: log, logger: log,
reloadOffset: time.Duration(rand.Intn(10)) * time.Second, reloadOffset: time.Duration(rand.Intn(10)) * time.Second,
instanceUUID: uuid.New(), instanceUUID: uuid.New(),
Outpost: outpost, Outpost: outpost,
wsBackoffMultiplier: 1, refreshHandlers: make([]func(), 0),
refreshHandlers: make([]func(), 0),
} }
ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset") ac.logger.WithField("offset", ac.reloadOffset.String()).Debug("HA Reload offset")
return ac return ac

View File

@ -127,7 +127,7 @@ func (fe *FlowExecutor) getAnswer(stage StageComponent) string {
return "" return ""
} }
func (fe *FlowExecutor) GetSession() *http.Cookie { func (fe *FlowExecutor) SessionCookie() *http.Cookie {
return fe.session return fe.session
} }

View File

@ -0,0 +1,19 @@
package flow
import "github.com/golang-jwt/jwt/v5"
type SessionCookieClaims struct {
jwt.Claims
SessionID string `json:"sid"`
Authenticated bool `json:"authenticated"`
}
func (fe *FlowExecutor) Session() *jwt.Token {
sc := fe.SessionCookie()
if sc == nil {
return nil
}
t, _, _ := jwt.NewParser().ParseUnverified(sc.Value, &SessionCookieClaims{})
return t
}

View File

@ -38,7 +38,14 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
username, err := instance.binder.GetUsername(bindDN) username, err := instance.binder.GetUsername(bindDN)
if err == nil { if err == nil {
selectedApp = instance.GetAppSlug() selectedApp = instance.GetAppSlug()
return instance.binder.Bind(username, req) c, err := instance.binder.Bind(username, req)
if c == ldap.LDAPResultSuccess {
f := instance.GetFlags(req.BindDN)
ls.connectionsSync.Lock()
ls.connections[f.SessionID()] = conn
ls.connectionsSync.Unlock()
}
return c, err
} else { } else {
req.Log().WithError(err).Debug("Username not for instance") req.Log().WithError(err).Debug("Username not for instance")
} }

View File

@ -27,8 +27,9 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
passed, err := fe.Execute() passed, err := fe.Execute()
flags := flags.UserFlags{ flags := flags.UserFlags{
Session: fe.GetSession(), Session: fe.SessionCookie(),
UserPk: flags.InvalidUserPK, SessionJWT: fe.Session(),
UserPk: flags.InvalidUserPK,
} }
// only set flags if we don't have flags for this DN yet // only set flags if we don't have flags for this DN yet
// as flags are only checked during the bind, we can remember whether a certain DN // as flags are only checked during the bind, we can remember whether a certain DN

View File

@ -0,0 +1,20 @@
package ldap
import "net"
func (ls *LDAPServer) Close(dn string, conn net.Conn) error {
ls.connectionsSync.Lock()
defer ls.connectionsSync.Unlock()
key := ""
for k, c := range ls.connections {
if c == conn {
key = k
break
}
}
if key == "" {
return nil
}
delete(ls.connections, key)
return nil
}

View File

@ -1,16 +1,30 @@
package flags package flags
import ( import (
"crypto/sha256"
"encoding/hex"
"net/http" "net/http"
"github.com/golang-jwt/jwt/v5"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/flow"
) )
const InvalidUserPK = -1 const InvalidUserPK = -1
type UserFlags struct { type UserFlags struct {
UserInfo *api.User UserInfo *api.User
UserPk int32 UserPk int32
CanSearch bool CanSearch bool
Session *http.Cookie Session *http.Cookie
SessionJWT *jwt.Token
}
func (uf UserFlags) SessionID() string {
if uf.SessionJWT == nil {
return ""
}
h := sha256.New()
h.Write([]byte(uf.SessionJWT.Claims.(*flow.SessionCookieClaims).SessionID))
return hex.EncodeToString(h.Sum(nil))
} }

View File

@ -18,21 +18,26 @@ import (
) )
type LDAPServer struct { type LDAPServer struct {
s *ldap.Server s *ldap.Server
log *log.Entry log *log.Entry
ac *ak.APIController ac *ak.APIController
cs *ak.CryptoStore cs *ak.CryptoStore
defaultCert *tls.Certificate defaultCert *tls.Certificate
providers []*ProviderInstance providers []*ProviderInstance
connections map[string]net.Conn
connectionsSync sync.Mutex
} }
func NewServer(ac *ak.APIController) ak.Outpost { func NewServer(ac *ak.APIController) ak.Outpost {
ls := &LDAPServer{ ls := &LDAPServer{
log: log.WithField("logger", "authentik.outpost.ldap"), log: log.WithField("logger", "authentik.outpost.ldap"),
ac: ac, ac: ac,
cs: ak.NewCryptoStore(ac.Client.CryptoApi), cs: ak.NewCryptoStore(ac.Client.CryptoApi),
providers: []*ProviderInstance{}, providers: []*ProviderInstance{},
connections: map[string]net.Conn{},
connectionsSync: sync.Mutex{},
} }
ac.AddEventHandler(ls.handleWSSessionEnd)
s := ldap.NewServer() s := ldap.NewServer()
s.EnforceLDAP = true s.EnforceLDAP = true
@ -50,6 +55,7 @@ func NewServer(ac *ak.APIController) ak.Outpost {
s.BindFunc("", ls) s.BindFunc("", ls)
s.UnbindFunc("", ls) s.UnbindFunc("", ls)
s.SearchFunc("", ls) s.SearchFunc("", ls)
s.CloseFunc("", ls)
return ls return ls
} }
@ -117,3 +123,23 @@ func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) {
p.binder.TimerFlowCacheExpiry(ctx) p.binder.TimerFlowCacheExpiry(ctx)
} }
} }
func (ls *LDAPServer) handleWSSessionEnd(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindSessionEnd {
return nil
}
mmsg := ak.EventArgsSessionEnd{}
err := msg.ArgsAs(&mmsg)
if err != nil {
return err
}
ls.connectionsSync.Lock()
defer ls.connectionsSync.Unlock()
ls.log.Info("Disconnecting session due to session end event")
conn, ok := ls.connections[mmsg.SessionID]
if !ok {
return nil
}
delete(ls.connections, mmsg.SessionID)
return conn.Close()
}

View File

@ -44,38 +44,40 @@ func (ds *DirectSearcher) SearchSubschema(req *search.Request) (ldap.ServerSearc
{ {
Name: "attributeTypes", Name: "attributeTypes",
Values: []string{ Values: []string{
"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )",
"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )",
"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 0.9.2342.19200300.100.1.1 NAME 'uid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 0.9.2342.19200300.100.1.1 NAME 'uid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 0.9.2342.19200300.100.1.3 NAME 'mail' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 0.9.2342.19200300.100.1.3 NAME 'mail' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 0.9.2342.19200300.100.1.41 NAME 'mobile' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 0.9.2342.19200300.100.1.41 NAME 'mobile' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.102 NAME 'memberOf' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' NO-USER-MODIFICATION )", "( 1.2.840.113556.1.2.102 NAME 'memberOf' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' NO-USER-MODIFICATION )",
"( 1.2.840.113556.1.2.13 NAME 'displayName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.131 NAME 'co' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.2.131 NAME 'co' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.141 NAME 'department' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.2.141 NAME 'department' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.2.146 NAME 'company' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.1 NAME 'name' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )", "( 1.2.840.113556.1.4.1 NAME 'name' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE NO-USER-MODIFICATION )",
"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.221 NAME 'sAMAccountName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.221 NAME 'sAMAccountName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.261 NAME 'division' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )", "( 1.2.840.113556.1.4.261 NAME 'division' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.44 NAME 'homeDirectory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.750 NAME 'groupType' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.2.840.113556.1.4.750 NAME 'groupType' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.2.840.113556.1.4.782 NAME 'objectCategory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' SINGLE-VALUE )", "( 1.2.840.113556.1.4.782 NAME 'objectCategory' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.0 NAME 'uidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.3.6.1.1.1.1.0 NAME 'uidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.1 NAME 'gidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )", "( 1.3.6.1.1.1.1.1 NAME 'gidNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.27' SINGLE-VALUE )",
"( 1.3.6.1.1.1.1.12 NAME 'memberUid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.26' )", "( 1.3.6.1.1.1.1.12 NAME 'memberUid' SYNTAX '1.3.6.1.4.1.1466.115.121.1.26' )",
"( 2.5.18.1 NAME 'createTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )",
"( 2.5.18.2 NAME 'modifyTimestamp' SYNTAX 1.3.6.1.4.1.1466.115.121.1.24 SINGLE-VALUE NO-USER-MODIFICATION )",
"( 2.5.21.2 NAME 'dITContentRules' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.5 NAME 'attributeTypes' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.21.6 NAME 'objectClasses' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' NO-USER-MODIFICATION )",
"( 2.5.4.0 NAME 'objectClass' SYNTAX '1.3.6.1.4.1.1466.115.121.1.38' NO-USER-MODIFICATION )",
"( 2.5.4.10 NAME 'o' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.11 NAME 'ou' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.12 NAME 'title' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.13 NAME 'description' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' )",
"( 2.5.4.20 NAME 'telephoneNumber' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.3 NAME 'cn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.31 NAME 'member' SYNTAX '1.3.6.1.4.1.1466.115.121.1.12' )",
"( 2.5.4.4 NAME 'sn' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.42 NAME 'givenName' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.6 NAME 'c' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
"( 2.5.4.7 NAME 'l' SYNTAX '1.3.6.1.4.1.1466.115.121.1.15' SINGLE-VALUE )",
// Custom attributes // Custom attributes
// Temporarily use 1.3.6.1.4.1.26027.1.1 as a base // Temporarily use 1.3.6.1.4.1.26027.1.1 as a base

View File

@ -66,7 +66,7 @@ func NewProxyServer(ac *ak.APIController) ak.Outpost {
globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic) globalMux.PathPrefix("/outpost.goauthentik.io/static").HandlerFunc(s.HandleStatic)
globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing)) globalMux.Path("/outpost.goauthentik.io/ping").HandlerFunc(sentryutils.SentryNoSample(s.HandlePing))
rootMux.PathPrefix("/").HandlerFunc(s.Handle) rootMux.PathPrefix("/").HandlerFunc(s.Handle)
ac.AddWSHandler(s.handleWSMessage) ac.AddEventHandler(s.handleWSMessage)
return s return s
} }

View File

@ -3,48 +3,27 @@ package proxyv2
import ( import (
"context" "context"
"github.com/mitchellh/mapstructure" "goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/proxyv2/application" "goauthentik.io/internal/outpost/proxyv2/application"
) )
type WSProviderSubType string func (ps *ProxyServer) handleWSMessage(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindSessionEnd {
const ( return nil
WSProviderSubTypeLogout WSProviderSubType = "logout"
)
type WSProviderMsg struct {
SubType WSProviderSubType `mapstructure:"sub_type"`
SessionID string `mapstructure:"session_id"`
}
func ParseWSProvider(args map[string]interface{}) (*WSProviderMsg, error) {
msg := &WSProviderMsg{}
err := mapstructure.Decode(args, &msg)
if err != nil {
return nil, err
} }
return msg, nil mmsg := ak.EventArgsSessionEnd{}
} err := msg.ArgsAs(&mmsg)
func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]interface{}) {
msg, err := ParseWSProvider(args)
if err != nil { if err != nil {
ps.log.WithError(err).Warning("invalid provider-specific ws message") return err
return
} }
switch msg.SubType { for _, p := range ps.apps {
case WSProviderSubTypeLogout: ps.log.WithField("provider", p.Host).Debug("Logging out")
for _, p := range ps.apps { err := p.Logout(ctx, func(c application.Claims) bool {
ps.log.WithField("provider", p.Host).Debug("Logging out") return c.Sid == mmsg.SessionID
err := p.Logout(ctx, func(c application.Claims) bool { })
return c.Sid == msg.SessionID if err != nil {
}) ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
if err != nil {
ps.log.WithField("provider", p.Host).WithError(err).Warning("failed to logout")
}
} }
default:
ps.log.WithField("sub_type", msg.SubType).Warning("invalid sub_type")
} }
return nil
} }

View File

@ -6,7 +6,6 @@ import (
"strconv" "strconv"
"sync" "sync"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wwt/guac" "github.com/wwt/guac"
@ -30,7 +29,7 @@ func NewServer(ac *ak.APIController) ak.Outpost {
connm: sync.RWMutex{}, connm: sync.RWMutex{},
conns: map[string]connection.Connection{}, conns: map[string]connection.Connection{},
} }
ac.AddWSHandler(rs.wsHandler) ac.AddEventHandler(rs.wsHandler)
return rs return rs
} }
@ -52,12 +51,14 @@ func parseIntOrZero(input string) int {
return x return x
} }
func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) { func (rs *RACServer) wsHandler(ctx context.Context, msg ak.Event) error {
if msg.Instruction != ak.EventKindProviderSpecific {
return nil
}
wsm := WSMessage{} wsm := WSMessage{}
err := mapstructure.Decode(args, &wsm) err := msg.ArgsAs(&wsm)
if err != nil { if err != nil {
rs.log.WithError(err).Warning("invalid ws message") return err
return
} }
config := guac.NewGuacamoleConfiguration() config := guac.NewGuacamoleConfiguration()
config.Protocol = wsm.Protocol config.Protocol = wsm.Protocol
@ -71,23 +72,23 @@ func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{})
} }
cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config) cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config)
if err != nil { if err != nil {
rs.log.WithError(err).Warning("failed to setup connection") return err
return
} }
cc.OnError = func(err error) { cc.OnError = func(err error) {
rs.connm.Lock() rs.connm.Lock()
delete(rs.conns, wsm.ConnID) delete(rs.conns, wsm.ConnID)
_ = rs.ac.SendWSHello(map[string]interface{}{ _ = rs.ac.SendEventHello(map[string]interface{}{
"active_connections": len(rs.conns), "active_connections": len(rs.conns),
}) })
rs.connm.Unlock() rs.connm.Unlock()
} }
rs.connm.Lock() rs.connm.Lock()
rs.conns[wsm.ConnID] = *cc rs.conns[wsm.ConnID] = *cc
_ = rs.ac.SendWSHello(map[string]interface{}{ _ = rs.ac.SendEventHello(map[string]interface{}{
"active_connections": len(rs.conns), "active_connections": len(rs.conns),
}) })
rs.connm.Unlock() rs.connm.Unlock()
return nil
} }
func (rs *RACServer) Start() error { func (rs *RACServer) Start() error {