From f90d6bb3d962da56f4b6b0229e8473905ee28478 Mon Sep 17 00:00:00 2001 From: Jens L Date: Wed, 23 Oct 2024 21:29:18 +0200 Subject: [PATCH] providers/oauth2: fix amr claim not set due to login event not associated (#11780) * providers/oauth2: fix amr claim not set due to login event not associated Signed-off-by: Jens Langhammer * add sid claim Signed-off-by: Jens Langhammer * import engine only once Signed-off-by: Jens Langhammer * remove manual sid extraction from proxy, add test, make session key hashing more obvious Signed-off-by: Jens Langhammer * unrelated string fix Signed-off-by: Jens Langhammer * fix format Signed-off-by: Jens Langhammer * fix tests Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer # Conflicts: # tests/e2e/test_provider_proxy.py --- authentik/events/signals.py | 19 ++- authentik/providers/oauth2/id_token.py | 17 ++- ...017_alter_oauth2provider_token_validity.py | 3 +- ..._remove_accesstoken_session_id_and_more.py | 113 ++++++++++++++++++ authentik/providers/oauth2/models.py | 15 ++- authentik/providers/oauth2/signals.py | 5 +- authentik/providers/oauth2/views/authorize.py | 11 +- authentik/providers/oauth2/views/token.py | 11 +- authentik/providers/proxy/tasks.py | 5 +- blueprints/system/providers-proxy.yaml | 4 - tests/e2e/test_provider_oidc.py | 1 + tests/e2e/test_provider_proxy.py | 13 +- .../identification/IdentificationStageForm.ts | 4 +- web/src/elements/oauth/UserAccessTokenList.ts | 2 +- .../elements/oauth/UserRefreshTokenList.ts | 2 +- 15 files changed, 190 insertions(+), 35 deletions(-) create mode 100644 authentik/providers/oauth2/migrations/0022_remove_accesstoken_session_id_and_more.py diff --git a/authentik/events/signals.py b/authentik/events/signals.py index 3f93e5deaa..8f23ccc947 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -1,13 +1,16 @@ """authentik events signal listener""" +from importlib import import_module from typing import Any +from django.conf import settings from django.contrib.auth.signals import user_logged_in, user_logged_out from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver from django.http import HttpRequest +from rest_framework.request import Request -from authentik.core.models import User +from authentik.core.models import AuthenticatedSession, User from authentik.core.signals import login_failed, password_changed from authentik.events.apps import SYSTEM_TASK_STATUS from authentik.events.models import Event, EventAction, SystemTask @@ -23,6 +26,7 @@ from authentik.stages.user_write.signals import user_write from authentik.tenants.utils import get_current_tenant SESSION_LOGIN_EVENT = "login_event" +_session_engine = import_module(settings.SESSION_ENGINE) @receiver(user_logged_in) @@ -40,11 +44,20 @@ def on_user_logged_in(sender, request: HttpRequest, user: User, **_): kwargs[PLAN_CONTEXT_METHOD_ARGS] = flow_plan.context.get(PLAN_CONTEXT_METHOD_ARGS, {}) event = Event.new(EventAction.LOGIN, **kwargs).from_http(request, user=user) request.session[SESSION_LOGIN_EVENT] = event + request.session.save() -def get_login_event(request: HttpRequest) -> Event | None: +def get_login_event(request_or_session: HttpRequest | AuthenticatedSession | None) -> Event | None: """Wrapper to get login event that can be mocked in tests""" - return request.session.get(SESSION_LOGIN_EVENT, None) + session = None + if not request_or_session: + return None + if isinstance(request_or_session, HttpRequest | Request): + session = request_or_session.session + if isinstance(request_or_session, AuthenticatedSession): + SessionStore = _session_engine.SessionStore + session = SessionStore(request_or_session.session_key) + return session.get(SESSION_LOGIN_EVENT, None) @receiver(user_logged_out) diff --git a/authentik/providers/oauth2/id_token.py b/authentik/providers/oauth2/id_token.py index 7b92804c62..27c5f2600f 100644 --- a/authentik/providers/oauth2/id_token.py +++ b/authentik/providers/oauth2/id_token.py @@ -1,6 +1,7 @@ """id_token utils""" from dataclasses import asdict, dataclass, field +from hashlib import sha256 from typing import TYPE_CHECKING, Any from django.db import models @@ -23,8 +24,13 @@ if TYPE_CHECKING: from authentik.providers.oauth2.models import BaseGrantModel, OAuth2Provider +def hash_session_key(session_key: str) -> str: + """Hash the session key for inclusion in JWTs as `sid`""" + return sha256(session_key.encode("ascii")).hexdigest() + + class SubModes(models.TextChoices): - """Mode after which 'sub' attribute is generateed, for compatibility reasons""" + """Mode after which 'sub' attribute is generated, for compatibility reasons""" HASHED_USER_ID = "hashed_user_id", _("Based on the Hashed User ID") USER_ID = "user_id", _("Based on user ID") @@ -51,7 +57,8 @@ class IDToken: and potentially other requested Claims. The ID Token is represented as a JSON Web Token (JWT) [JWT]. - https://openid.net/specs/openid-connect-core-1_0.html#IDToken""" + https://openid.net/specs/openid-connect-core-1_0.html#IDToken + https://www.iana.org/assignments/jwt/jwt.xhtml""" # Issuer, https://www.rfc-editor.org/rfc/rfc7519.html#section-4.1.1 iss: str | None = None @@ -79,6 +86,8 @@ class IDToken: nonce: str | None = None # Access Token hash value, http://openid.net/specs/openid-connect-core-1_0.html at_hash: str | None = None + # Session ID, https://openid.net/specs/openid-connect-frontchannel-1_0.html#ClaimsContents + sid: str | None = None claims: dict[str, Any] = field(default_factory=dict) @@ -116,9 +125,11 @@ class IDToken: now = timezone.now() id_token.iat = int(now.timestamp()) id_token.auth_time = int(token.auth_time.timestamp()) + if token.session: + id_token.sid = hash_session_key(token.session.session_key) # We use the timestamp of the user's last successful login (EventAction.LOGIN) for auth_time - auth_event = get_login_event(request) + auth_event = get_login_event(token.session) if auth_event: # Also check which method was used for authentication method = auth_event.context.get(PLAN_CONTEXT_METHOD, "") diff --git a/authentik/providers/oauth2/migrations/0007_auto_20201016_1107_squashed_0017_alter_oauth2provider_token_validity.py b/authentik/providers/oauth2/migrations/0007_auto_20201016_1107_squashed_0017_alter_oauth2provider_token_validity.py index 81656d9cc5..6f5eddc402 100644 --- a/authentik/providers/oauth2/migrations/0007_auto_20201016_1107_squashed_0017_alter_oauth2provider_token_validity.py +++ b/authentik/providers/oauth2/migrations/0007_auto_20201016_1107_squashed_0017_alter_oauth2provider_token_validity.py @@ -3,6 +3,7 @@ import django.db.models.deletion from django.apps.registry import Apps from django.db import migrations, models +from django.db.backends.base.schema import BaseDatabaseSchemaEditor import authentik.lib.utils.time @@ -14,7 +15,7 @@ scope_uid_map = { } -def set_managed_flag(apps: Apps, schema_editor): +def set_managed_flag(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): ScopeMapping = apps.get_model("authentik_providers_oauth2", "ScopeMapping") db_alias = schema_editor.connection.alias for mapping in ScopeMapping.objects.using(db_alias).filter(name__startswith="Autogenerated "): diff --git a/authentik/providers/oauth2/migrations/0022_remove_accesstoken_session_id_and_more.py b/authentik/providers/oauth2/migrations/0022_remove_accesstoken_session_id_and_more.py new file mode 100644 index 0000000000..081f45962c --- /dev/null +++ b/authentik/providers/oauth2/migrations/0022_remove_accesstoken_session_id_and_more.py @@ -0,0 +1,113 @@ +# Generated by Django 5.0.9 on 2024-10-23 13:38 + +from hashlib import sha256 +import django.db.models.deletion +from django.db import migrations, models +from django.apps.registry import Apps +from django.db.backends.base.schema import BaseDatabaseSchemaEditor +from authentik.lib.migrations import progress_bar + + +def migrate_session(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): + AuthenticatedSession = apps.get_model("authentik_core", "authenticatedsession") + AuthorizationCode = apps.get_model("authentik_providers_oauth2", "authorizationcode") + AccessToken = apps.get_model("authentik_providers_oauth2", "accesstoken") + RefreshToken = apps.get_model("authentik_providers_oauth2", "refreshtoken") + db_alias = schema_editor.connection.alias + + print(f"\nFetching session keys, this might take a couple of minutes...") + session_ids = {} + for session in progress_bar(AuthenticatedSession.objects.using(db_alias).all()): + session_ids[sha256(session.session_key.encode("ascii")).hexdigest()] = session.session_key + for model in [AuthorizationCode, AccessToken, RefreshToken]: + print( + f"\nAdding session to {model._meta.verbose_name}, this might take a couple of minutes..." + ) + for code in progress_bar(model.objects.using(db_alias).all()): + if code.session_id_old not in session_ids: + continue + code.session = ( + AuthenticatedSession.objects.using(db_alias) + .filter(session_key=session_ids[code.session_id_old]) + .first() + ) + code.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0040_provider_invalidation_flow"), + ("authentik_providers_oauth2", "0021_oauth2provider_encryption_key_and_more"), + ] + + operations = [ + migrations.RenameField( + model_name="accesstoken", + old_name="session_id", + new_name="session_id_old", + ), + migrations.RenameField( + model_name="authorizationcode", + old_name="session_id", + new_name="session_id_old", + ), + migrations.RenameField( + model_name="refreshtoken", + old_name="session_id", + new_name="session_id_old", + ), + migrations.AddField( + model_name="accesstoken", + name="session", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + to="authentik_core.authenticatedsession", + ), + ), + migrations.AddField( + model_name="authorizationcode", + name="session", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + to="authentik_core.authenticatedsession", + ), + ), + migrations.AddField( + model_name="devicetoken", + name="session", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + to="authentik_core.authenticatedsession", + ), + ), + migrations.AddField( + model_name="refreshtoken", + name="session", + field=models.ForeignKey( + default=None, + null=True, + on_delete=django.db.models.deletion.SET_DEFAULT, + to="authentik_core.authenticatedsession", + ), + ), + migrations.RunPython(migrate_session), + migrations.RemoveField( + model_name="accesstoken", + name="session_id_old", + ), + migrations.RemoveField( + model_name="authorizationcode", + name="session_id_old", + ), + migrations.RemoveField( + model_name="refreshtoken", + name="session_id_old", + ), + ] diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index 59ac273701..d56c7afe26 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -24,7 +24,13 @@ from rest_framework.serializers import Serializer from structlog.stdlib import get_logger from authentik.brands.models import WebfingerProvider -from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User +from authentik.core.models import ( + AuthenticatedSession, + ExpiringModel, + PropertyMapping, + Provider, + User, +) from authentik.crypto.models import CertificateKeyPair from authentik.lib.generators import generate_code_fixed_length, generate_id, generate_key from authentik.lib.models import SerializerModel @@ -354,7 +360,9 @@ class BaseGrantModel(models.Model): revoked = models.BooleanField(default=False) _scope = models.TextField(default="", verbose_name=_("Scopes")) auth_time = models.DateTimeField(verbose_name="Authentication time") - session_id = models.CharField(default="", blank=True) + session = models.ForeignKey( + AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None + ) class Meta: abstract = True @@ -486,6 +494,9 @@ class DeviceToken(ExpiringModel): device_code = models.TextField(default=generate_key) user_code = models.TextField(default=generate_code_fixed_length) _scope = models.TextField(default="", verbose_name=_("Scopes")) + session = models.ForeignKey( + AuthenticatedSession, null=True, on_delete=models.SET_DEFAULT, default=None + ) @property def scope(self) -> list[str]: diff --git a/authentik/providers/oauth2/signals.py b/authentik/providers/oauth2/signals.py index ee0f4ed9c8..16996ced33 100644 --- a/authentik/providers/oauth2/signals.py +++ b/authentik/providers/oauth2/signals.py @@ -1,5 +1,3 @@ -from hashlib import sha256 - from django.contrib.auth.signals import user_logged_out from django.dispatch import receiver from django.http import HttpRequest @@ -13,5 +11,4 @@ def user_logged_out_oauth_access_token(sender, request: HttpRequest, user: User, """Revoke access tokens upon user logout""" if not request.session or not request.session.session_key: return - hashed_session_key = sha256(request.session.session_key.encode("ascii")).hexdigest() - AccessToken.objects.filter(user=user, session_id=hashed_session_key).delete() + AccessToken.objects.filter(user=user, session__session_key=request.session.session_key).delete() diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 1f613df45b..156f634f26 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -2,7 +2,6 @@ from dataclasses import InitVar, dataclass, field from datetime import timedelta -from hashlib import sha256 from json import dumps from re import error as RegexError from re import fullmatch @@ -16,7 +15,7 @@ from django.utils import timezone from django.utils.translation import gettext as _ from structlog.stdlib import get_logger -from authentik.core.models import Application +from authentik.core.models import Application, AuthenticatedSession from authentik.events.models import Event, EventAction from authentik.events.signals import get_login_event from authentik.flows.challenge import ( @@ -319,7 +318,9 @@ class OAuthAuthorizationParams: expires=now + timedelta_from_string(self.provider.access_code_validity), scope=self.scope, nonce=self.nonce, - session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(), + session=AuthenticatedSession.objects.filter( + session_key=request.session.session_key + ).first(), ) if self.code_challenge and self.code_challenge_method: @@ -611,7 +612,9 @@ class OAuthFulfillmentStage(StageView): expires=access_token_expiry, provider=self.provider, auth_time=auth_event.created if auth_event else now, - session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(), + session=AuthenticatedSession.objects.filter( + session_key=self.request.session.session_key + ).first(), ) id_token = IDToken.new(self.provider, token, self.request) diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 99245ba548..aa3a9fcc26 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -574,7 +574,7 @@ class TokenView(View): # Keep same scopes as previous token scope=self.params.authorization_code.scope, auth_time=self.params.authorization_code.auth_time, - session_id=self.params.authorization_code.session_id, + session=self.params.authorization_code.session, ) access_id_token = IDToken.new( self.provider, @@ -602,7 +602,7 @@ class TokenView(View): expires=refresh_token_expiry, provider=self.provider, auth_time=self.params.authorization_code.auth_time, - session_id=self.params.authorization_code.session_id, + session=self.params.authorization_code.session, ) id_token = IDToken.new( self.provider, @@ -635,7 +635,7 @@ class TokenView(View): # Keep same scopes as previous token scope=self.params.refresh_token.scope, auth_time=self.params.refresh_token.auth_time, - session_id=self.params.refresh_token.session_id, + session=self.params.refresh_token.session, ) access_token.id_token = IDToken.new( self.provider, @@ -651,7 +651,7 @@ class TokenView(View): expires=refresh_token_expiry, provider=self.provider, auth_time=self.params.refresh_token.auth_time, - session_id=self.params.refresh_token.session_id, + session=self.params.refresh_token.session, ) id_token = IDToken.new( self.provider, @@ -709,13 +709,14 @@ class TokenView(View): raise DeviceCodeError("authorization_pending") now = timezone.now() access_token_expiry = now + timedelta_from_string(self.provider.access_token_validity) - auth_event = get_login_event(self.request) + auth_event = get_login_event(self.params.device_code.session) access_token = AccessToken( provider=self.provider, user=self.params.device_code.user, expires=access_token_expiry, scope=self.params.device_code.scope, auth_time=auth_event.created if auth_event else now, + session=self.params.device_code.session, ) access_token.id_token = IDToken.new( self.provider, diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py index 8b2f94c035..e9c449fcf0 100644 --- a/authentik/providers/proxy/tasks.py +++ b/authentik/providers/proxy/tasks.py @@ -1,13 +1,12 @@ """proxy provider tasks""" -from hashlib import sha256 - from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.db import DatabaseError, InternalError, ProgrammingError from authentik.outposts.consumer import OUTPOST_GROUP from authentik.outposts.models import Outpost, OutpostType +from authentik.providers.oauth2.id_token import hash_session_key from authentik.providers.proxy.models import ProxyProvider from authentik.root.celery import CELERY_APP @@ -26,7 +25,7 @@ def proxy_set_defaults(): def proxy_on_logout(session_id: str): """Update outpost instances connected to a single outpost""" layer = get_channel_layer() - hashed_session_id = sha256(session_id.encode("ascii")).hexdigest() + hashed_session_id = hash_session_key(session_id) for outpost in Outpost.objects.filter(type=OutpostType.PROXY): group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} async_to_sync(layer.group_send)( diff --git a/blueprints/system/providers-proxy.yaml b/blueprints/system/providers-proxy.yaml index 2fd4db560f..1214d157d1 100644 --- a/blueprints/system/providers-proxy.yaml +++ b/blueprints/system/providers-proxy.yaml @@ -14,11 +14,7 @@ entries: expression: | # This mapping is used by the authentik proxy. It passes extra user attributes, # which are used for example for the HTTP-Basic Authentication mapping. - session_id = None - if "token" in request.context: - session_id = request.context.get("token").session_id return { - "sid": session_id, "ak_proxy": { "user_attributes": request.user.group_attributes(request), "is_superuser": request.user.is_superuser, diff --git a/tests/e2e/test_provider_oidc.py b/tests/e2e/test_provider_oidc.py index 350c4eaa0d..e51ac90a64 100644 --- a/tests/e2e/test_provider_oidc.py +++ b/tests/e2e/test_provider_oidc.py @@ -171,6 +171,7 @@ class TestProviderOAuth2OIDC(SeleniumTestCase): body = loads(self.driver.find_element(By.CSS_SELECTOR, "pre").text) self.assertEqual(body["IDTokenClaims"]["nickname"], self.user.username) + self.assertEqual(body["IDTokenClaims"]["amr"], ["pwd"]) self.assertEqual(body["UserInfo"]["nickname"], self.user.username) self.assertEqual(body["IDTokenClaims"]["name"], self.user.name) diff --git a/tests/e2e/test_provider_proxy.py b/tests/e2e/test_provider_proxy.py index bb4844dcae..86cf35684b 100644 --- a/tests/e2e/test_provider_proxy.py +++ b/tests/e2e/test_provider_proxy.py @@ -2,6 +2,7 @@ from base64 import b64encode from dataclasses import asdict +from json import loads from sys import platform from time import sleep from typing import Any @@ -10,6 +11,7 @@ from unittest.case import skip, skipUnless from channels.testing import ChannelsLiveServerTestCase from docker.client import DockerClient, from_env from docker.models.containers import Container +from jwt import decode from selenium.webdriver.common.by import By from authentik.blueprints.tests import apply_blueprint, reconcile_app @@ -115,8 +117,15 @@ class TestProviderProxy(SeleniumTestCase): sleep(1) full_body_text = self.driver.find_element(By.CSS_SELECTOR, "pre").text - self.assertIn(f"X-Authentik-Username: {self.user.username}", full_body_text) - self.assertIn("X-Foo: bar", full_body_text) + body = loads(full_body_text) + + self.assertEqual(body["headers"]["X-Authentik-Username"], [self.user.username]) + self.assertEqual(body["headers"]["X-Foo"], ["bar"]) + raw_jwt: str = body["headers"]["X-Authentik-Jwt"][0] + jwt = decode(raw_jwt, options={"verify_signature": False}) + + self.assertIsNotNone(jwt["sid"]) + self.assertIsNotNone(jwt["ak_proxy"]) self.driver.get("http://localhost:9000/outpost.goauthentik.io/sign_out") sleep(2) diff --git a/web/src/admin/stages/identification/IdentificationStageForm.ts b/web/src/admin/stages/identification/IdentificationStageForm.ts index 9123fba71a..6a9c65de08 100644 --- a/web/src/admin/stages/identification/IdentificationStageForm.ts +++ b/web/src/admin/stages/identification/IdentificationStageForm.ts @@ -236,8 +236,8 @@ export class IdentificationStageForm extends BaseStageForm

${msg( diff --git a/web/src/elements/oauth/UserAccessTokenList.ts b/web/src/elements/oauth/UserAccessTokenList.ts index 96ea31616c..53cdd92a34 100644 --- a/web/src/elements/oauth/UserAccessTokenList.ts +++ b/web/src/elements/oauth/UserAccessTokenList.ts @@ -34,7 +34,7 @@ export class UserOAuthAccessTokenList extends Table { } checkbox = true; - order = "expires"; + order = "-expires"; columns(): TableColumn[] { return [ diff --git a/web/src/elements/oauth/UserRefreshTokenList.ts b/web/src/elements/oauth/UserRefreshTokenList.ts index d0581b1d89..f3cd743701 100644 --- a/web/src/elements/oauth/UserRefreshTokenList.ts +++ b/web/src/elements/oauth/UserRefreshTokenList.ts @@ -35,7 +35,7 @@ export class UserOAuthRefreshTokenList extends Table { checkbox = true; clearOnRefresh = true; - order = "expires"; + order = "-expires"; columns(): TableColumn[] { return [