core: fix source flow_manager not always appending save stage (#9659)
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@ -13,7 +13,7 @@ from django.utils.translation import gettext as _
|
|||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
|
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
|
||||||
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
|
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.flows.exceptions import FlowNonApplicableException
|
from authentik.flows.exceptions import FlowNonApplicableException
|
||||||
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
|
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
|
||||||
@ -206,13 +206,9 @@ class SourceFlowManager:
|
|||||||
|
|
||||||
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
|
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
|
||||||
"""Hook to override stages which are appended to the flow"""
|
"""Hook to override stages which are appended to the flow"""
|
||||||
if not self.source.enrollment_flow:
|
return [
|
||||||
return []
|
in_memory_stage(PostSourceStage),
|
||||||
if flow.slug == self.source.enrollment_flow.slug:
|
]
|
||||||
return [
|
|
||||||
in_memory_stage(PostUserEnrollmentStage),
|
|
||||||
]
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _prepare_flow(
|
def _prepare_flow(
|
||||||
self,
|
self,
|
||||||
@ -266,6 +262,8 @@ class SourceFlowManager:
|
|||||||
)
|
)
|
||||||
# We run the Flow planner here so we can pass the Pending user in the context
|
# We run the Flow planner here so we can pass the Pending user in the context
|
||||||
planner = FlowPlanner(flow)
|
planner = FlowPlanner(flow)
|
||||||
|
# We append some stages so the initial flow we get might be empty
|
||||||
|
planner.allow_empty_flows = True
|
||||||
plan = planner.plan(self.request, kwargs)
|
plan = planner.plan(self.request, kwargs)
|
||||||
for stage in self.get_stages_to_append(flow):
|
for stage in self.get_stages_to_append(flow):
|
||||||
plan.append_stage(stage)
|
plan.append_stage(stage)
|
||||||
@ -324,7 +322,7 @@ class SourceFlowManager:
|
|||||||
reverse(
|
reverse(
|
||||||
"authentik_core:if-user",
|
"authentik_core:if-user",
|
||||||
)
|
)
|
||||||
+ f"#/settings;page-{self.source.slug}"
|
+ "#/settings;page-sources"
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_enroll(
|
def handle_enroll(
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from authentik.flows.stage import StageView
|
|||||||
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
|
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
|
||||||
|
|
||||||
|
|
||||||
class PostUserEnrollmentStage(StageView):
|
class PostSourceStage(StageView):
|
||||||
"""Dynamically injected stage which saves the Connection after
|
"""Dynamically injected stage which saves the Connection after
|
||||||
the user has been enrolled."""
|
the user has been enrolled."""
|
||||||
|
|
||||||
@ -21,10 +21,12 @@ class PostUserEnrollmentStage(StageView):
|
|||||||
]
|
]
|
||||||
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
|
||||||
connection.user = user
|
connection.user = user
|
||||||
|
linked = connection.pk is None
|
||||||
connection.save()
|
connection.save()
|
||||||
Event.new(
|
if linked:
|
||||||
EventAction.SOURCE_LINKED,
|
Event.new(
|
||||||
message="Linked Source",
|
EventAction.SOURCE_LINKED,
|
||||||
source=connection.source,
|
message="Linked Source",
|
||||||
).from_http(self.request)
|
source=connection.source,
|
||||||
|
).from_http(self.request)
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|||||||
@ -2,11 +2,15 @@
|
|||||||
|
|
||||||
from django.contrib.auth.models import AnonymousUser
|
from django.contrib.auth.models import AnonymousUser
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.urls import reverse
|
||||||
from guardian.utils import get_anonymous_user
|
from guardian.utils import get_anonymous_user
|
||||||
|
|
||||||
from authentik.core.models import SourceUserMatchingModes, User
|
from authentik.core.models import SourceUserMatchingModes, User
|
||||||
from authentik.core.sources.flow_manager import Action
|
from authentik.core.sources.flow_manager import Action
|
||||||
|
from authentik.core.sources.stage import PostSourceStage
|
||||||
from authentik.core.tests.utils import create_test_flow
|
from authentik.core.tests.utils import create_test_flow
|
||||||
|
from authentik.flows.planner import FlowPlan
|
||||||
|
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||||
from authentik.lib.generators import generate_id
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.lib.tests.utils import get_request
|
from authentik.lib.tests.utils import get_request
|
||||||
from authentik.policies.denied import AccessDeniedResponse
|
from authentik.policies.denied import AccessDeniedResponse
|
||||||
@ -21,41 +25,55 @@ class TestSourceFlowManager(TestCase):
|
|||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.source: OAuthSource = OAuthSource.objects.create(name="test")
|
self.authentication_flow = create_test_flow()
|
||||||
|
self.enrollment_flow = create_test_flow()
|
||||||
|
self.source: OAuthSource = OAuthSource.objects.create(
|
||||||
|
name=generate_id(),
|
||||||
|
slug=generate_id(),
|
||||||
|
authentication_flow=self.authentication_flow,
|
||||||
|
enrollment_flow=self.enrollment_flow,
|
||||||
|
)
|
||||||
self.identifier = generate_id()
|
self.identifier = generate_id()
|
||||||
|
|
||||||
def test_unauthenticated_enroll(self):
|
def test_unauthenticated_enroll(self):
|
||||||
"""Test un-authenticated user enrolling"""
|
"""Test un-authenticated user enrolling"""
|
||||||
flow_manager = OAuthSourceFlowManager(
|
request = get_request("/", user=AnonymousUser())
|
||||||
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
|
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||||
)
|
|
||||||
action, _ = flow_manager.get_action()
|
action, _ = flow_manager.get_action()
|
||||||
self.assertEqual(action, Action.ENROLL)
|
self.assertEqual(action, Action.ENROLL)
|
||||||
flow_manager.get_flow()
|
response = flow_manager.get_flow()
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||||
|
self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
|
||||||
|
|
||||||
def test_unauthenticated_auth(self):
|
def test_unauthenticated_auth(self):
|
||||||
"""Test un-authenticated user authenticating"""
|
"""Test un-authenticated user authenticating"""
|
||||||
UserOAuthSourceConnection.objects.create(
|
UserOAuthSourceConnection.objects.create(
|
||||||
user=get_anonymous_user(), source=self.source, identifier=self.identifier
|
user=get_anonymous_user(), source=self.source, identifier=self.identifier
|
||||||
)
|
)
|
||||||
|
request = get_request("/", user=AnonymousUser())
|
||||||
flow_manager = OAuthSourceFlowManager(
|
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||||
self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
|
|
||||||
)
|
|
||||||
action, _ = flow_manager.get_action()
|
action, _ = flow_manager.get_action()
|
||||||
self.assertEqual(action, Action.AUTH)
|
self.assertEqual(action, Action.AUTH)
|
||||||
flow_manager.get_flow()
|
response = flow_manager.get_flow()
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
|
||||||
|
self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
|
||||||
|
|
||||||
def test_authenticated_link(self):
|
def test_authenticated_link(self):
|
||||||
"""Test authenticated user linking"""
|
"""Test authenticated user linking"""
|
||||||
user = User.objects.create(username="foo", email="foo@bar.baz")
|
user = User.objects.create(username="foo", email="foo@bar.baz")
|
||||||
flow_manager = OAuthSourceFlowManager(
|
request = get_request("/", user=user)
|
||||||
self.source, get_request("/", user=user), self.identifier, {}
|
flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
|
||||||
)
|
|
||||||
action, connection = flow_manager.get_action()
|
action, connection = flow_manager.get_action()
|
||||||
self.assertEqual(action, Action.LINK)
|
self.assertEqual(action, Action.LINK)
|
||||||
self.assertIsNone(connection.pk)
|
self.assertIsNone(connection.pk)
|
||||||
flow_manager.get_flow()
|
response = flow_manager.get_flow()
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
response.url,
|
||||||
|
reverse("authentik_core:if-user") + "#/settings;page-sources",
|
||||||
|
)
|
||||||
|
|
||||||
def test_unauthenticated_link(self):
|
def test_unauthenticated_link(self):
|
||||||
"""Test un-authenticated user linking"""
|
"""Test un-authenticated user linking"""
|
||||||
|
|||||||
Reference in New Issue
Block a user