From 8b74b839832cc2e514440c91f3a28450fa7f9f7a Mon Sep 17 00:00:00 2001 From: Jens L Date: Thu, 9 May 2024 19:04:32 +0200 Subject: [PATCH] core: fix source flow_manager not always appending save stage (#9659) Signed-off-by: Jens Langhammer --- authentik/core/sources/flow_manager.py | 16 +++---- authentik/core/sources/stage.py | 14 +++--- .../core/tests/test_source_flow_manager.py | 46 +++++++++++++------ 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 183ca1491a..b5828debc1 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -13,7 +13,7 @@ from django.utils.translation import gettext as _ from structlog.stdlib import get_logger from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection -from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage +from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage from authentik.events.models import Event, EventAction from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage @@ -206,13 +206,9 @@ class SourceFlowManager: def get_stages_to_append(self, flow: Flow) -> list[Stage]: """Hook to override stages which are appended to the flow""" - if not self.source.enrollment_flow: - return [] - if flow.slug == self.source.enrollment_flow.slug: - return [ - in_memory_stage(PostUserEnrollmentStage), - ] - return [] + return [ + in_memory_stage(PostSourceStage), + ] def _prepare_flow( self, @@ -266,6 +262,8 @@ class SourceFlowManager: ) # We run the Flow planner here so we can pass the Pending user in the context planner = FlowPlanner(flow) + # We append some stages so the initial flow we get might be empty + planner.allow_empty_flows = True plan = planner.plan(self.request, kwargs) for stage in self.get_stages_to_append(flow): plan.append_stage(stage) @@ -324,7 +322,7 @@ class SourceFlowManager: reverse( "authentik_core:if-user", ) - + f"#/settings;page-{self.source.slug}" + + "#/settings;page-sources" ) def handle_enroll( diff --git a/authentik/core/sources/stage.py b/authentik/core/sources/stage.py index 12703906ca..de863f1982 100644 --- a/authentik/core/sources/stage.py +++ b/authentik/core/sources/stage.py @@ -10,7 +10,7 @@ from authentik.flows.stage import StageView PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" -class PostUserEnrollmentStage(StageView): +class PostSourceStage(StageView): """Dynamically injected stage which saves the Connection after the user has been enrolled.""" @@ -21,10 +21,12 @@ class PostUserEnrollmentStage(StageView): ] user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] connection.user = user + linked = connection.pk is None connection.save() - Event.new( - EventAction.SOURCE_LINKED, - message="Linked Source", - source=connection.source, - ).from_http(self.request) + if linked: + Event.new( + EventAction.SOURCE_LINKED, + message="Linked Source", + source=connection.source, + ).from_http(self.request) return self.executor.stage_ok() diff --git a/authentik/core/tests/test_source_flow_manager.py b/authentik/core/tests/test_source_flow_manager.py index e6614933d4..5b75ec7859 100644 --- a/authentik/core/tests/test_source_flow_manager.py +++ b/authentik/core/tests/test_source_flow_manager.py @@ -2,11 +2,15 @@ from django.contrib.auth.models import AnonymousUser from django.test import TestCase +from django.urls import reverse from guardian.utils import get_anonymous_user from authentik.core.models import SourceUserMatchingModes, User from authentik.core.sources.flow_manager import Action +from authentik.core.sources.stage import PostSourceStage from authentik.core.tests.utils import create_test_flow +from authentik.flows.planner import FlowPlan +from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.lib.generators import generate_id from authentik.lib.tests.utils import get_request from authentik.policies.denied import AccessDeniedResponse @@ -21,41 +25,55 @@ class TestSourceFlowManager(TestCase): def setUp(self) -> None: super().setUp() - self.source: OAuthSource = OAuthSource.objects.create(name="test") + self.authentication_flow = create_test_flow() + self.enrollment_flow = create_test_flow() + self.source: OAuthSource = OAuthSource.objects.create( + name=generate_id(), + slug=generate_id(), + authentication_flow=self.authentication_flow, + enrollment_flow=self.enrollment_flow, + ) self.identifier = generate_id() def test_unauthenticated_enroll(self): """Test un-authenticated user enrolling""" - flow_manager = OAuthSourceFlowManager( - self.source, get_request("/", user=AnonymousUser()), self.identifier, {} - ) + request = get_request("/", user=AnonymousUser()) + flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) - flow_manager.get_flow() + response = flow_manager.get_flow() + self.assertEqual(response.status_code, 302) + flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] + self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) def test_unauthenticated_auth(self): """Test un-authenticated user authenticating""" UserOAuthSourceConnection.objects.create( user=get_anonymous_user(), source=self.source, identifier=self.identifier ) - - flow_manager = OAuthSourceFlowManager( - self.source, get_request("/", user=AnonymousUser()), self.identifier, {} - ) + request = get_request("/", user=AnonymousUser()) + flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) action, _ = flow_manager.get_action() self.assertEqual(action, Action.AUTH) - flow_manager.get_flow() + response = flow_manager.get_flow() + self.assertEqual(response.status_code, 302) + flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] + self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) def test_authenticated_link(self): """Test authenticated user linking""" user = User.objects.create(username="foo", email="foo@bar.baz") - flow_manager = OAuthSourceFlowManager( - self.source, get_request("/", user=user), self.identifier, {} - ) + request = get_request("/", user=user) + flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) action, connection = flow_manager.get_action() self.assertEqual(action, Action.LINK) self.assertIsNone(connection.pk) - flow_manager.get_flow() + response = flow_manager.get_flow() + self.assertEqual(response.status_code, 302) + self.assertEqual( + response.url, + reverse("authentik_core:if-user") + "#/settings;page-sources", + ) def test_unauthenticated_link(self): """Test un-authenticated user linking"""