From 83b02a17d508c9a3b0edfe1998b1ffebf0ec229d Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Wed, 7 Aug 2024 19:14:22 +0200 Subject: [PATCH] sources: add property mappings for all oauth and saml sources (#8771) Co-authored-by: Jens L. --- authentik/blueprints/v1/importer.py | 2 + authentik/core/api/sources.py | 42 +- ...matching_mode_alter_group_name_and_more.py | 67 + authentik/core/models.py | 44 +- authentik/core/sources/flow_manager.py | 223 ++- .../core/tests/test_source_flow_manager.py | 52 +- ..._source_flow_manager_group_update_stage.py | 237 +++ .../sources/oauth/api/property_mappings.py | 31 + authentik/sources/oauth/api/source.py | 2 + .../sources/oauth/api/source_connection.py | 18 +- ...008_groupoauthsourceconnection_and_more.py | 60 + authentik/sources/oauth/models.py | 53 +- .../oauth/tests/test_property_mappings.py | 109 ++ .../sources/oauth/tests/test_type_azure_ad.py | 4 +- .../sources/oauth/tests/test_type_discord.py | 4 +- .../sources/oauth/tests/test_type_github.py | 25 +- .../sources/oauth/tests/test_type_gitlab.py | 4 +- .../sources/oauth/tests/test_type_google.py | 7 +- .../sources/oauth/tests/test_type_mailcow.py | 4 +- .../sources/oauth/tests/test_type_openid.py | 6 +- .../sources/oauth/tests/test_type_patreon.py | 7 +- .../sources/oauth/tests/test_type_twitch.py | 4 +- .../sources/oauth/tests/test_type_twitter.py | 4 +- authentik/sources/oauth/types/apple.py | 15 +- authentik/sources/oauth/types/azure_ad.py | 19 +- authentik/sources/oauth/types/discord.py | 17 +- authentik/sources/oauth/types/facebook.py | 17 +- authentik/sources/oauth/types/github.py | 44 +- authentik/sources/oauth/types/gitlab.py | 17 +- authentik/sources/oauth/types/google.py | 15 +- authentik/sources/oauth/types/mailcow.py | 17 +- authentik/sources/oauth/types/oidc.py | 18 +- authentik/sources/oauth/types/okta.py | 18 +- authentik/sources/oauth/types/patreon.py | 17 +- authentik/sources/oauth/types/reddit.py | 18 +- authentik/sources/oauth/types/registry.py | 15 + authentik/sources/oauth/types/twitch.py | 17 +- authentik/sources/oauth/types/twitter.py | 19 +- authentik/sources/oauth/urls.py | 8 +- authentik/sources/oauth/views/callback.py | 30 +- authentik/sources/plex/api/source.py | 3 +- authentik/sources/plex/plex.py | 6 +- .../sources/saml/api/property_mappings.py | 31 + authentik/sources/saml/api/source.py | 1 + .../sources/saml/api/source_connection.py | 18 +- ...rceconnection_samlsourcepropertymapping.py | 57 + authentik/sources/saml/models.py | 91 +- .../sources/saml/processors/constants.py | 2 + authentik/sources/saml/processors/response.py | 72 +- .../fixtures/response_success_groups.xml | 46 + .../saml/tests/test_property_mappings.py | 135 ++ authentik/sources/saml/tests/test_response.py | 11 +- authentik/sources/saml/urls.py | 8 +- blueprints/schema.json | 361 ++++- schema.yml | 1382 ++++++++++++++++- tests/e2e/test_source_oauth_oauth1.py | 17 +- .../PropertyMappingLDAPSourceForm.ts | 4 +- .../PropertyMappingListPage.ts | 2 + .../PropertyMappingOAuthSourceForm.ts | 75 + .../PropertyMappingSAMLSourceForm.ts | 75 + .../PropertyMappingWizard.ts | 2 + .../admin/sources/oauth/OAuthSourceForm.ts | 100 +- web/src/admin/sources/oauth/utils.ts | 18 +- web/src/admin/sources/saml/SAMLSourceForm.ts | 98 +- 64 files changed, 3631 insertions(+), 314 deletions(-) create mode 100644 authentik/core/migrations/0039_source_group_matching_mode_alter_group_name_and_more.py create mode 100644 authentik/core/tests/test_source_flow_manager_group_update_stage.py create mode 100644 authentik/sources/oauth/api/property_mappings.py create mode 100644 authentik/sources/oauth/migrations/0008_groupoauthsourceconnection_and_more.py create mode 100644 authentik/sources/oauth/tests/test_property_mappings.py create mode 100644 authentik/sources/saml/api/property_mappings.py create mode 100644 authentik/sources/saml/migrations/0015_groupsamlsourceconnection_samlsourcepropertymapping.py create mode 100644 authentik/sources/saml/tests/fixtures/response_success_groups.xml create mode 100644 authentik/sources/saml/tests/test_property_mappings.py create mode 100644 web/src/admin/property-mappings/PropertyMappingOAuthSourceForm.ts create mode 100644 web/src/admin/property-mappings/PropertyMappingSAMLSourceForm.ts diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index a2377e630f..2143fde053 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -33,6 +33,7 @@ from authentik.blueprints.v1.common import ( from authentik.blueprints.v1.meta.registry import BaseMetaModel, registry from authentik.core.models import ( AuthenticatedSession, + GroupSourceConnection, PropertyMapping, Provider, Source, @@ -91,6 +92,7 @@ def excluded_models() -> list[type[Model]]: Source, PropertyMapping, UserSourceConnection, + GroupSourceConnection, Stage, OutpostServiceConnection, Policy, diff --git a/authentik/core/api/sources.py b/authentik/core/api/sources.py index 7a3212b7f3..015cbd52b9 100644 --- a/authentik/core/api/sources.py +++ b/authentik/core/api/sources.py @@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT from authentik.core.api.object_types import TypesMixin from authentik.core.api.used_by import UsedByMixin from authentik.core.api.utils import MetaNameSerializer, ModelSerializer -from authentik.core.models import Source, UserSourceConnection +from authentik.core.models import GroupSourceConnection, Source, UserSourceConnection from authentik.core.types import UserSettingSerializer from authentik.lib.utils.file import ( FilePathSerializer, @@ -194,3 +194,43 @@ class UserSourceConnectionViewSet( search_fields = ["source__slug"] filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] ordering = ["source__slug", "pk"] + + +class GroupSourceConnectionSerializer(SourceSerializer): + """Group Source Connection Serializer""" + + source = SourceSerializer(read_only=True) + + class Meta: + model = GroupSourceConnection + fields = [ + "pk", + "group", + "source", + "identifier", + "created", + ] + extra_kwargs = { + "group": {"read_only": True}, + "identifier": {"read_only": True}, + "created": {"read_only": True}, + } + + +class GroupSourceConnectionViewSet( + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + UsedByMixin, + mixins.ListModelMixin, + GenericViewSet, +): + """Group-source connection Viewset""" + + queryset = GroupSourceConnection.objects.all() + serializer_class = GroupSourceConnectionSerializer + permission_classes = [OwnerSuperuserPermissions] + filterset_fields = ["group", "source__slug"] + search_fields = ["source__slug"] + filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] + ordering = ["source__slug", "pk"] diff --git a/authentik/core/migrations/0039_source_group_matching_mode_alter_group_name_and_more.py b/authentik/core/migrations/0039_source_group_matching_mode_alter_group_name_and_more.py new file mode 100644 index 0000000000..5c7f64fc86 --- /dev/null +++ b/authentik/core/migrations/0039_source_group_matching_mode_alter_group_name_and_more.py @@ -0,0 +1,67 @@ +# Generated by Django 5.0.7 on 2024-08-01 18:52 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0038_source_authentik_c_enabled_d72365_idx"), + ] + + operations = [ + migrations.AddField( + model_name="source", + name="group_matching_mode", + field=models.TextField( + choices=[ + ("identifier", "Use the source-specific identifier"), + ( + "name_link", + "Link to a group with identical name. Can have security implications when a group name is used with another source.", + ), + ( + "name_deny", + "Use the group name, but deny enrollment when the name already exists.", + ), + ], + default="identifier", + help_text="How the source determines if an existing group should be used or a new group created.", + ), + ), + migrations.AlterField( + model_name="group", + name="name", + field=models.TextField(verbose_name="name"), + ), + migrations.CreateModel( + name="GroupSourceConnection", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name="ID" + ), + ), + ("created", models.DateTimeField(auto_now_add=True)), + ("last_updated", models.DateTimeField(auto_now=True)), + ("identifier", models.TextField()), + ( + "group", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="authentik_core.group" + ), + ), + ( + "source", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="authentik_core.source" + ), + ), + ], + options={ + "unique_together": {("group", "source")}, + }, + ), + ] diff --git a/authentik/core/models.py b/authentik/core/models.py index deaa9df923..1bcbe64746 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -173,7 +173,7 @@ class Group(SerializerModel, AttributesMixin): group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) - name = models.CharField(_("name"), max_length=80) + name = models.TextField(_("name")) is_superuser = models.BooleanField( default=False, help_text=_("Users added to this group will be superusers.") ) @@ -583,6 +583,19 @@ class SourceUserMatchingModes(models.TextChoices): ) +class SourceGroupMatchingModes(models.TextChoices): + """Different modes a source can handle new/returning groups""" + + IDENTIFIER = "identifier", _("Use the source-specific identifier") + NAME_LINK = "name_link", _( + "Link to a group with identical name. Can have security implications " + "when a group name is used with another source." + ) + NAME_DENY = "name_deny", _( + "Use the group name, but deny enrollment when the name already exists." + ) + + class Source(ManagedModel, SerializerModel, PolicyBindingModel): """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" @@ -632,6 +645,14 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): "a new user enrolled." ), ) + group_matching_mode = models.TextField( + choices=SourceGroupMatchingModes.choices, + default=SourceGroupMatchingModes.IDENTIFIER, + help_text=_( + "How the source determines if an existing group should be used or " + "a new group created." + ), + ) objects = InheritanceManager() @@ -727,6 +748,27 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): unique_together = (("user", "source"),) +class GroupSourceConnection(SerializerModel, CreatedUpdatedModel): + """Connection between Group and Source.""" + + group = models.ForeignKey(Group, on_delete=models.CASCADE) + source = models.ForeignKey(Source, on_delete=models.CASCADE) + identifier = models.TextField() + + objects = InheritanceManager() + + @property + def serializer(self) -> type[Serializer]: + """Get serializer for this model""" + raise NotImplementedError + + def __str__(self) -> str: + return f"Group-source connection (group={self.group_id}, source={self.source_id})" + + class Meta: + unique_together = (("group", "source"),) + + class ExpiringModel(models.Model): """Base Model which can expire, and is automatically cleaned up.""" diff --git a/authentik/core/sources/flow_manager.py b/authentik/core/sources/flow_manager.py index 9391efda12..86b78d47ef 100644 --- a/authentik/core/sources/flow_manager.py +++ b/authentik/core/sources/flow_manager.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any from django.contrib import messages -from django.db import IntegrityError +from django.db import IntegrityError, transaction from django.db.models.query_utils import Q from django.http import HttpRequest, HttpResponse from django.shortcuts import redirect @@ -12,8 +12,20 @@ from django.urls import reverse from django.utils.translation import gettext as _ from structlog.stdlib import get_logger -from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection -from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage +from authentik.core.models import ( + Group, + GroupSourceConnection, + Source, + SourceGroupMatchingModes, + SourceUserMatchingModes, + User, + UserSourceConnection, +) +from authentik.core.sources.mapper import SourceMapper +from authentik.core.sources.stage import ( + PLAN_CONTEXT_SOURCES_CONNECTION, + PostSourceStage, +) from authentik.events.models import Event, EventAction from authentik.flows.exceptions import FlowNonApplicableException from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage @@ -36,7 +48,10 @@ from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.user_write.stage import PLAN_CONTEXT_USER_PATH +LOGGER = get_logger() + SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" # nosec +PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" class Action(Enum): @@ -70,48 +85,69 @@ class SourceFlowManager: or deny the request.""" source: Source + mapper: SourceMapper request: HttpRequest identifier: str - connection_type: type[UserSourceConnection] = UserSourceConnection + user_connection_type: type[UserSourceConnection] = UserSourceConnection + group_connection_type: type[GroupSourceConnection] = GroupSourceConnection - enroll_info: dict[str, Any] + user_info: dict[str, Any] policy_context: dict[str, Any] + user_properties: dict[str, Any | dict[str, Any]] + groups_properties: dict[str, dict[str, Any | dict[str, Any]]] def __init__( self, source: Source, request: HttpRequest, identifier: str, - enroll_info: dict[str, Any], + user_info: dict[str, Any], + policy_context: dict[str, Any], ) -> None: self.source = source + self.mapper = SourceMapper(self.source) self.request = request self.identifier = identifier - self.enroll_info = enroll_info + self.user_info = user_info self._logger = get_logger().bind(source=source, identifier=identifier) - self.policy_context = {} + self.policy_context = policy_context + + self.user_properties = self.mapper.build_object_properties( + object_type=User, request=request, user=None, **self.user_info + ) + self.groups_properties = { + group_id: self.mapper.build_object_properties( + object_type=Group, + request=request, + user=None, + group_id=group_id, + **self.user_info, + ) + for group_id in self.user_properties.setdefault("groups", []) + } + del self.user_properties["groups"] def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911 """decide which action should be taken""" - new_connection = self.connection_type(source=self.source, identifier=self.identifier) + new_connection = self.user_connection_type(source=self.source, identifier=self.identifier) # When request is authenticated, always link if self.request.user.is_authenticated: new_connection.user = self.request.user - new_connection = self.update_connection(new_connection, **kwargs) + new_connection = self.update_user_connection(new_connection, **kwargs) return Action.LINK, new_connection - existing_connections = self.connection_type.objects.filter( + existing_connections = self.user_connection_type.objects.filter( source=self.source, identifier=self.identifier ) if existing_connections.exists(): connection = existing_connections.first() - return Action.AUTH, self.update_connection(connection, **kwargs) + return Action.AUTH, self.update_user_connection(connection, **kwargs) # No connection exists, but we match on identifier, so enroll if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: # We don't save the connection here cause it doesn't have a user assigned yet - return Action.ENROLL, self.update_connection(new_connection, **kwargs) + return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) # Check for existing users with matching attributes query = Q() @@ -120,24 +156,24 @@ class SourceFlowManager: SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY, ]: - if not self.enroll_info.get("email", None): - self._logger.warning("Refusing to use none email", source=self.source) + if not self.user_properties.get("email", None): + self._logger.warning("Refusing to use none email") return Action.DENY, None - query = Q(email__exact=self.enroll_info.get("email", None)) + query = Q(email__exact=self.user_properties.get("email", None)) if self.source.user_matching_mode in [ SourceUserMatchingModes.USERNAME_LINK, SourceUserMatchingModes.USERNAME_DENY, ]: - if not self.enroll_info.get("username", None): - self._logger.warning("Refusing to use none username", source=self.source) + if not self.user_properties.get("username", None): + self._logger.warning("Refusing to use none username") return Action.DENY, None - query = Q(username__exact=self.enroll_info.get("username", None)) + query = Q(username__exact=self.user_properties.get("username", None)) self._logger.debug("trying to link with existing user", query=query) matching_users = User.objects.filter(query) # No matching users, always enroll if not matching_users.exists(): self._logger.debug("no matching users found, enrolling") - return Action.ENROLL, self.update_connection(new_connection, **kwargs) + return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) user = matching_users.first() if self.source.user_matching_mode in [ @@ -145,7 +181,7 @@ class SourceFlowManager: SourceUserMatchingModes.USERNAME_LINK, ]: new_connection.user = user - new_connection = self.update_connection(new_connection, **kwargs) + new_connection = self.update_user_connection(new_connection, **kwargs) return Action.LINK, new_connection if self.source.user_matching_mode in [ SourceUserMatchingModes.EMAIL_DENY, @@ -156,10 +192,10 @@ class SourceFlowManager: # Should never get here as default enroll case is returned above. return Action.DENY, None # pragma: no cover - def update_connection( + def update_user_connection( self, connection: UserSourceConnection, **kwargs ) -> UserSourceConnection: # pragma: no cover - """Optionally make changes to the connection after it is looked up/created.""" + """Optionally make changes to the user connection after it is looked up/created.""" return connection def get_flow(self, **kwargs) -> HttpResponse: @@ -215,25 +251,31 @@ class SourceFlowManager: flow: Flow | None, connection: UserSourceConnection, stages: list[StageView] | None = None, - **kwargs, + **flow_context, ) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" - kwargs.update( + # Ensure redirect is carried through when user was trying to + # authorize application + final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get( + NEXT_ARG_NAME, "authentik_core:if-user" + ) + flow_context.update( { # Since we authenticate the user by their token, they have no backend set PLAN_CONTEXT_AUTHENTICATION_BACKEND: BACKEND_INBUILT, PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_SOURCE: self.source, PLAN_CONTEXT_SOURCES_CONNECTION: connection, + PLAN_CONTEXT_SOURCE_GROUPS: self.groups_properties, } ) - kwargs.update(self.policy_context) + flow_context.update(self.policy_context) if SESSION_KEY_OVERRIDE_FLOW_TOKEN in self.request.session: token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) self._logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) plan = token.plan plan.context[PLAN_CONTEXT_IS_RESTORED] = token - plan.context.update(kwargs) + plan.context.update(flow_context) for stage in self.get_stages_to_append(flow): plan.append_stage(stage) if stages: @@ -252,8 +294,8 @@ class SourceFlowManager: final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get( NEXT_ARG_NAME, "authentik_core:if-user" ) - if PLAN_CONTEXT_REDIRECT not in kwargs: - kwargs[PLAN_CONTEXT_REDIRECT] = final_redirect + if PLAN_CONTEXT_REDIRECT not in flow_context: + flow_context[PLAN_CONTEXT_REDIRECT] = final_redirect if not flow: return bad_request_message( @@ -265,9 +307,12 @@ class SourceFlowManager: # We append some stages so the initial flow we get might be empty planner.allow_empty_flows = True planner.use_cache = False - plan = planner.plan(self.request, kwargs) + plan = planner.plan(self.request, flow_context) for stage in self.get_stages_to_append(flow): plan.append_stage(stage) + plan.append_stage( + in_memory_stage(GroupUpdateStage, group_connection_type=self.group_connection_type) + ) if stages: for stage in stages: plan.append_stage(stage) @@ -354,7 +399,123 @@ class SourceFlowManager: ) ], **{ - PLAN_CONTEXT_PROMPT: delete_none_values(self.enroll_info), + PLAN_CONTEXT_PROMPT: delete_none_values(self.user_properties), PLAN_CONTEXT_USER_PATH: self.source.get_user_path(), }, ) + + +class GroupUpdateStage(StageView): + """Dynamically injected stage which updates the user after enrollment/authentication.""" + + def get_action( + self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] + ) -> tuple[Action, GroupSourceConnection | None]: + """decide which action should be taken""" + new_connection = self.group_connection_type(source=self.source, identifier=group_id) + + existing_connections = self.group_connection_type.objects.filter( + source=self.source, identifier=group_id + ) + if existing_connections.exists(): + return Action.LINK, existing_connections.first() + # No connection exists, but we match on identifier, so enroll + if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER: + # We don't save the connection here cause it doesn't have a user assigned yet + return Action.ENROLL, new_connection + + # Check for existing groups with matching attributes + query = Q() + if self.source.group_matching_mode in [ + SourceGroupMatchingModes.NAME_LINK, + SourceGroupMatchingModes.NAME_DENY, + ]: + if not group_properties.get("name", None): + LOGGER.warning( + "Refusing to use none group name", source=self.source, group_id=group_id + ) + return Action.DENY, None + query = Q(name__exact=group_properties.get("name")) + LOGGER.debug( + "trying to link with existing group", source=self.source, query=query, group_id=group_id + ) + matching_groups = Group.objects.filter(query) + # No matching groups, always enroll + if not matching_groups.exists(): + LOGGER.debug( + "no matching groups found, enrolling", source=self.source, group_id=group_id + ) + return Action.ENROLL, new_connection + + group = matching_groups.first() + if self.source.group_matching_mode in [ + SourceGroupMatchingModes.NAME_LINK, + ]: + new_connection.group = group + return Action.LINK, new_connection + if self.source.group_matching_mode in [ + SourceGroupMatchingModes.NAME_DENY, + ]: + LOGGER.info( + "denying source because group exists", + source=self.source, + group=group, + group_id=group_id, + ) + return Action.DENY, None + # Should never get here as default enroll case is returned above. + return Action.DENY, None # pragma: no cover + + def handle_group( + self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] + ) -> Group | None: + action, connection = self.get_action(group_id, group_properties) + if action == Action.ENROLL: + group = Group.objects.create(**group_properties) + connection.group = group + connection.save() + return group + elif action == Action.LINK: + group = connection.group + group.update_attributes(group_properties) + connection.save() + return group + + return None + + def handle_groups(self) -> bool: + self.source: Source = self.executor.plan.context[PLAN_CONTEXT_SOURCE] + self.user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] + self.group_connection_type: GroupSourceConnection = ( + self.executor.current_stage.group_connection_type + ) + + raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ + PLAN_CONTEXT_SOURCE_GROUPS + ] + groups: list[Group] = [] + + for group_id, group_properties in raw_groups.items(): + group = self.handle_group(group_id, group_properties) + if not group: + return False + groups.append(group) + + with transaction.atomic(): + self.user.ak_groups.remove( + *self.user.ak_groups.filter(groupsourceconnection__source=self.source) + ) + self.user.ak_groups.add(*groups) + + return True + + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """Stage used after the user has been enrolled to sync their groups from source data""" + if self.handle_groups(): + return self.executor.stage_ok() + else: + return self.executor.stage_invalid("Failed to update groups. Please try again later.") + + def post(self, request: HttpRequest) -> HttpResponse: + """Wrapper for post requests""" + return self.get(request) diff --git a/authentik/core/tests/test_source_flow_manager.py b/authentik/core/tests/test_source_flow_manager.py index 5b75ec7859..bcd38449c6 100644 --- a/authentik/core/tests/test_source_flow_manager.py +++ b/authentik/core/tests/test_source_flow_manager.py @@ -38,7 +38,9 @@ class TestSourceFlowManager(TestCase): def test_unauthenticated_enroll(self): """Test un-authenticated user enrolling""" request = get_request("/", user=AnonymousUser()) - flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) + flow_manager = OAuthSourceFlowManager( + self.source, request, self.identifier, {"info": {}}, {} + ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) response = flow_manager.get_flow() @@ -52,7 +54,9 @@ class TestSourceFlowManager(TestCase): user=get_anonymous_user(), source=self.source, identifier=self.identifier ) request = get_request("/", user=AnonymousUser()) - flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) + flow_manager = OAuthSourceFlowManager( + self.source, request, self.identifier, {"info": {}}, {} + ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.AUTH) response = flow_manager.get_flow() @@ -64,7 +68,9 @@ class TestSourceFlowManager(TestCase): """Test authenticated user linking""" user = User.objects.create(username="foo", email="foo@bar.baz") request = get_request("/", user=user) - flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) + flow_manager = OAuthSourceFlowManager( + self.source, request, self.identifier, {"info": {}}, {} + ) action, connection = flow_manager.get_action() self.assertEqual(action, Action.LINK) self.assertIsNone(connection.pk) @@ -77,7 +83,9 @@ class TestSourceFlowManager(TestCase): def test_unauthenticated_link(self): """Test un-authenticated user linking""" - flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {}) + flow_manager = OAuthSourceFlowManager( + self.source, get_request("/"), self.identifier, {"info": {}}, {} + ) action, connection = flow_manager.get_action() self.assertEqual(action, Action.LINK) self.assertIsNone(connection.pk) @@ -90,7 +98,7 @@ class TestSourceFlowManager(TestCase): # Without email, deny flow_manager = OAuthSourceFlowManager( - self.source, get_request("/", user=AnonymousUser()), self.identifier, {} + self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {} ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.DENY) @@ -100,7 +108,12 @@ class TestSourceFlowManager(TestCase): self.source, get_request("/", user=AnonymousUser()), self.identifier, - {"email": "foo@bar.baz"}, + { + "info": { + "email": "foo@bar.baz", + }, + }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.LINK) @@ -113,7 +126,7 @@ class TestSourceFlowManager(TestCase): # Without username, deny flow_manager = OAuthSourceFlowManager( - self.source, get_request("/", user=AnonymousUser()), self.identifier, {} + self.source, get_request("/", user=AnonymousUser()), self.identifier, {"info": {}}, {} ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.DENY) @@ -123,7 +136,10 @@ class TestSourceFlowManager(TestCase): self.source, get_request("/", user=AnonymousUser()), self.identifier, - {"username": "foo"}, + { + "info": {"username": "foo"}, + }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.LINK) @@ -140,8 +156,11 @@ class TestSourceFlowManager(TestCase): get_request("/", user=AnonymousUser()), self.identifier, { - "username": "bar", + "info": { + "username": "bar", + }, }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) @@ -151,7 +170,10 @@ class TestSourceFlowManager(TestCase): self.source, get_request("/", user=AnonymousUser()), self.identifier, - {"username": "foo"}, + { + "info": {"username": "foo"}, + }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.DENY) @@ -165,7 +187,10 @@ class TestSourceFlowManager(TestCase): self.source, get_request("/", user=AnonymousUser()), self.identifier, - {"username": "foo"}, + { + "info": {"username": "foo"}, + }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) @@ -191,7 +216,10 @@ class TestSourceFlowManager(TestCase): self.source, get_request("/", user=AnonymousUser()), self.identifier, - {"username": "foo"}, + { + "info": {"username": "foo"}, + }, + {}, ) action, _ = flow_manager.get_action() self.assertEqual(action, Action.ENROLL) diff --git a/authentik/core/tests/test_source_flow_manager_group_update_stage.py b/authentik/core/tests/test_source_flow_manager_group_update_stage.py new file mode 100644 index 0000000000..edc7d49366 --- /dev/null +++ b/authentik/core/tests/test_source_flow_manager_group_update_stage.py @@ -0,0 +1,237 @@ +"""Test Source flow_manager group update stage""" + +from django.test import RequestFactory + +from authentik.core.models import Group, SourceGroupMatchingModes +from authentik.core.sources.flow_manager import PLAN_CONTEXT_SOURCE_GROUPS, GroupUpdateStage +from authentik.core.tests.utils import create_test_admin_user, create_test_flow +from authentik.flows.models import in_memory_stage +from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, PLAN_CONTEXT_SOURCE, FlowPlan +from authentik.flows.tests import FlowTestCase +from authentik.flows.views.executor import FlowExecutorView +from authentik.lib.generators import generate_id +from authentik.sources.oauth.models import GroupOAuthSourceConnection, OAuthSource + + +class TestSourceFlowManager(FlowTestCase): + """Test Source flow_manager group update stage""" + + def setUp(self) -> None: + super().setUp() + self.factory = RequestFactory() + self.authentication_flow = create_test_flow() + self.enrollment_flow = create_test_flow() + self.source: OAuthSource = OAuthSource.objects.create( + name=generate_id(), + slug=generate_id(), + authentication_flow=self.authentication_flow, + enrollment_flow=self.enrollment_flow, + ) + self.identifier = generate_id() + self.user = create_test_admin_user() + + def test_nonexistant_group(self): + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "group 1": { + "name": "group 1", + }, + }, + }, + ), + ), + request=request, + ) + self.assertTrue(stage.handle_groups()) + self.assertTrue(Group.objects.filter(name="group 1").exists()) + self.assertTrue(self.user.ak_groups.filter(name="group 1").exists()) + self.assertTrue( + GroupOAuthSourceConnection.objects.filter( + group=Group.objects.get(name="group 1"), source=self.source + ).exists() + ) + + def test_nonexistant_group_name_link(self): + self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK + self.source.save() + + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "group 1": { + "name": "group 1", + }, + }, + }, + ), + ), + request=request, + ) + self.assertTrue(stage.handle_groups()) + self.assertTrue(Group.objects.filter(name="group 1").exists()) + self.assertTrue(self.user.ak_groups.filter(name="group 1").exists()) + self.assertTrue( + GroupOAuthSourceConnection.objects.filter( + group=Group.objects.get(name="group 1"), source=self.source + ).exists() + ) + + def test_existant_group_name_link(self): + self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK + self.source.save() + group = Group.objects.create(name="group 1") + + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "group 1": { + "name": "group 1", + }, + }, + }, + ), + ), + request=request, + ) + self.assertTrue(stage.handle_groups()) + self.assertTrue(Group.objects.filter(name="group 1").exists()) + self.assertTrue(self.user.ak_groups.filter(name="group 1").exists()) + self.assertTrue( + GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists() + ) + + def test_nonexistant_group_name_deny(self): + self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY + self.source.save() + + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "group 1": { + "name": "group 1", + }, + }, + }, + ), + ), + request=request, + ) + self.assertTrue(stage.handle_groups()) + self.assertTrue(Group.objects.filter(name="group 1").exists()) + self.assertTrue(self.user.ak_groups.filter(name="group 1").exists()) + self.assertTrue( + GroupOAuthSourceConnection.objects.filter( + group=Group.objects.get(name="group 1"), source=self.source + ).exists() + ) + + def test_existant_group_name_deny(self): + self.source.group_matching_mode = SourceGroupMatchingModes.NAME_DENY + self.source.save() + group = Group.objects.create(name="group 1") + + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "group 1": { + "name": "group 1", + }, + }, + }, + ), + ), + request=request, + ) + self.assertFalse(stage.handle_groups()) + self.assertFalse(self.user.ak_groups.filter(name="group 1").exists()) + self.assertFalse( + GroupOAuthSourceConnection.objects.filter(group=group, source=self.source).exists() + ) + + def test_group_updates(self): + self.source.group_matching_mode = SourceGroupMatchingModes.NAME_LINK + self.source.save() + + other_group = Group.objects.create(name="other group") + old_group = Group.objects.create(name="old group") + new_group = Group.objects.create(name="new group") + self.user.ak_groups.set([other_group, old_group]) + GroupOAuthSourceConnection.objects.create( + group=old_group, source=self.source, identifier=old_group.name + ) + GroupOAuthSourceConnection.objects.create( + group=new_group, source=self.source, identifier=new_group.name + ) + + request = self.factory.get("/") + stage = GroupUpdateStage( + FlowExecutorView( + current_stage=in_memory_stage( + GroupUpdateStage, group_connection_type=GroupOAuthSourceConnection + ), + plan=FlowPlan( + flow_pk=generate_id(), + context={ + PLAN_CONTEXT_SOURCE: self.source, + PLAN_CONTEXT_PENDING_USER: self.user, + PLAN_CONTEXT_SOURCE_GROUPS: { + "new group": { + "name": "new group", + }, + }, + }, + ), + ), + request=request, + ) + self.assertTrue(stage.handle_groups()) + self.assertFalse(self.user.ak_groups.filter(name="old group").exists()) + self.assertTrue(self.user.ak_groups.filter(name="other group").exists()) + self.assertTrue(self.user.ak_groups.filter(name="new group").exists()) + self.assertEqual(self.user.ak_groups.count(), 2) diff --git a/authentik/sources/oauth/api/property_mappings.py b/authentik/sources/oauth/api/property_mappings.py new file mode 100644 index 0000000000..d9e61a8662 --- /dev/null +++ b/authentik/sources/oauth/api/property_mappings.py @@ -0,0 +1,31 @@ +"""OAuth source property mappings API""" + +from rest_framework.viewsets import ModelViewSet + +from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer +from authentik.core.api.used_by import UsedByMixin +from authentik.sources.oauth.models import OAuthSourcePropertyMapping + + +class OAuthSourcePropertyMappingSerializer(PropertyMappingSerializer): + """OAuthSourcePropertyMapping Serializer""" + + class Meta(PropertyMappingSerializer.Meta): + model = OAuthSourcePropertyMapping + + +class OAuthSourcePropertyMappingFilter(PropertyMappingFilterSet): + """Filter for OAuthSourcePropertyMapping""" + + class Meta(PropertyMappingFilterSet.Meta): + model = OAuthSourcePropertyMapping + + +class OAuthSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet): + """OAuthSourcePropertyMapping Viewset""" + + queryset = OAuthSourcePropertyMapping.objects.all() + serializer_class = OAuthSourcePropertyMappingSerializer + filterset_class = OAuthSourcePropertyMappingFilter + search_fields = ["name"] + ordering = ["name"] diff --git a/authentik/sources/oauth/api/source.py b/authentik/sources/oauth/api/source.py index 28de050faf..ebba67d2f2 100644 --- a/authentik/sources/oauth/api/source.py +++ b/authentik/sources/oauth/api/source.py @@ -116,6 +116,7 @@ class OAuthSourceSerializer(SourceSerializer): class Meta: model = OAuthSource fields = SourceSerializer.Meta.fields + [ + "group_matching_mode", "provider_type", "request_token_url", "authorization_url", @@ -158,6 +159,7 @@ class OAuthSourceFilter(FilterSet): "enrollment_flow", "policy_engine_mode", "user_matching_mode", + "group_matching_mode", "provider_type", "request_token_url", "authorization_url", diff --git a/authentik/sources/oauth/api/source_connection.py b/authentik/sources/oauth/api/source_connection.py index b49b230a30..98daad0184 100644 --- a/authentik/sources/oauth/api/source_connection.py +++ b/authentik/sources/oauth/api/source_connection.py @@ -3,10 +3,12 @@ from rest_framework.viewsets import ModelViewSet from authentik.core.api.sources import ( + GroupSourceConnectionSerializer, + GroupSourceConnectionViewSet, UserSourceConnectionSerializer, UserSourceConnectionViewSet, ) -from authentik.sources.oauth.models import UserOAuthSourceConnection +from authentik.sources.oauth.models import GroupOAuthSourceConnection, UserOAuthSourceConnection class UserOAuthSourceConnectionSerializer(UserSourceConnectionSerializer): @@ -26,3 +28,17 @@ class UserOAuthSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet queryset = UserOAuthSourceConnection.objects.all() serializer_class = UserOAuthSourceConnectionSerializer + + +class GroupOAuthSourceConnectionSerializer(GroupSourceConnectionSerializer): + """OAuth Group-Source connection Serializer""" + + class Meta(GroupSourceConnectionSerializer.Meta): + model = GroupOAuthSourceConnection + + +class GroupOAuthSourceConnectionViewSet(GroupSourceConnectionViewSet, ModelViewSet): + """Group-source connection Viewset""" + + queryset = GroupOAuthSourceConnection.objects.all() + serializer_class = GroupOAuthSourceConnectionSerializer diff --git a/authentik/sources/oauth/migrations/0008_groupoauthsourceconnection_and_more.py b/authentik/sources/oauth/migrations/0008_groupoauthsourceconnection_and_more.py new file mode 100644 index 0000000000..f103c710c0 --- /dev/null +++ b/authentik/sources/oauth/migrations/0008_groupoauthsourceconnection_and_more.py @@ -0,0 +1,60 @@ +# Generated by Django 5.0.7 on 2024-08-01 18:52 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), + ( + "authentik_sources_oauth", + "0007_oauthsource_oidc_jwks_oauthsource_oidc_jwks_url_and_more", + ), + ] + + operations = [ + migrations.CreateModel( + name="GroupOAuthSourceConnection", + fields=[ + ( + "groupsourceconnection_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="authentik_core.groupsourceconnection", + ), + ), + ], + options={ + "verbose_name": "Group OAuth Source Connection", + "verbose_name_plural": "Group OAuth Source Connections", + }, + bases=("authentik_core.groupsourceconnection",), + ), + migrations.CreateModel( + name="OAuthSourcePropertyMapping", + fields=[ + ( + "propertymapping_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="authentik_core.propertymapping", + ), + ), + ], + options={ + "verbose_name": "OAuth Source Property Mapping", + "verbose_name_plural": "OAuth Source Property Mappings", + }, + bases=("authentik_core.propertymapping",), + ), + ] diff --git a/authentik/sources/oauth/models.py b/authentik/sources/oauth/models.py index d05665bcb4..73ed11b773 100644 --- a/authentik/sources/oauth/models.py +++ b/authentik/sources/oauth/models.py @@ -9,7 +9,12 @@ from django.utils.translation import gettext_lazy as _ from rest_framework.serializers import Serializer from authentik.core.api.object_types import CreatableType, NonCreatableType -from authentik.core.models import Source, UserSourceConnection +from authentik.core.models import ( + GroupSourceConnection, + PropertyMapping, + Source, + UserSourceConnection, +) from authentik.core.types import UILoginButton, UserSettingSerializer if TYPE_CHECKING: @@ -73,6 +78,16 @@ class OAuthSource(NonCreatableType, Source): return OAuthSourceSerializer + @property + def property_mapping_type(self) -> type[PropertyMapping]: + return OAuthSourcePropertyMapping + + def get_base_user_properties(self, **kwargs): + return self.source_type().get_base_user_properties(source=self, **kwargs) + + def get_base_group_properties(self, **kwargs): + return self.source_type().get_base_group_properties(source=self, **kwargs) + @property def icon_url(self) -> str | None: # When listing source types, this property might be retrieved from an abstract @@ -248,6 +263,26 @@ class RedditOAuthSource(CreatableType, OAuthSource): verbose_name_plural = _("Reddit OAuth Sources") +class OAuthSourcePropertyMapping(PropertyMapping): + """Map OAuth properties to User or Group object attributes""" + + @property + def component(self) -> str: + return "ak-property-mapping-oauth-source-form" + + @property + def serializer(self) -> type[Serializer]: + from authentik.sources.oauth.api.property_mappings import ( + OAuthSourcePropertyMappingSerializer, + ) + + return OAuthSourcePropertyMappingSerializer + + class Meta: + verbose_name = _("OAuth Source Property Mapping") + verbose_name_plural = _("OAuth Source Property Mappings") + + class UserOAuthSourceConnection(UserSourceConnection): """Authorized remote OAuth provider.""" @@ -269,3 +304,19 @@ class UserOAuthSourceConnection(UserSourceConnection): class Meta: verbose_name = _("User OAuth Source Connection") verbose_name_plural = _("User OAuth Source Connections") + + +class GroupOAuthSourceConnection(GroupSourceConnection): + """Group-source connection""" + + @property + def serializer(self) -> type[Serializer]: + from authentik.sources.oauth.api.source_connection import ( + GroupOAuthSourceConnectionSerializer, + ) + + return GroupOAuthSourceConnectionSerializer + + class Meta: + verbose_name = _("Group OAuth Source Connection") + verbose_name_plural = _("Group OAuth Source Connections") diff --git a/authentik/sources/oauth/tests/test_property_mappings.py b/authentik/sources/oauth/tests/test_property_mappings.py new file mode 100644 index 0000000000..47cb8ca2f4 --- /dev/null +++ b/authentik/sources/oauth/tests/test_property_mappings.py @@ -0,0 +1,109 @@ +"""Apple Type tests""" + +from copy import deepcopy + +from django.contrib.auth.models import AnonymousUser +from django.test import TestCase + +from authentik.lib.generators import generate_id +from authentik.lib.tests.utils import get_request +from authentik.sources.oauth.models import OAuthSource, OAuthSourcePropertyMapping +from authentik.sources.oauth.views.callback import OAuthSourceFlowManager + +INFO = { + "sub": "83692", + "name": "Alice Adams", + "email": "alice@example.com", + "department": "Engineering", + "birthdate": "1975-12-31", + "nickname": "foo", +} +IDENTIFIER = INFO["sub"] + + +class TestPropertyMappings(TestCase): + """OAuth Source tests""" + + def setUp(self): + self.source = OAuthSource.objects.create( + name="test", + slug="test", + provider_type="openidconnect", + authorization_url="", + profile_url="", + consumer_key=generate_id(), + ) + + def test_user_base_properties(self): + """Test user base properties""" + properties = self.source.get_base_user_properties(info=INFO) + self.assertEqual( + properties, + { + "email": "alice@example.com", + "groups": [], + "name": "Alice Adams", + "username": "foo", + }, + ) + + def test_group_base_properties(self): + """Test group base properties""" + info = deepcopy(INFO) + info["groups"] = ["group 1", "group 2"] + properties = self.source.get_base_user_properties(info=info) + self.assertEqual(properties["groups"], ["group 1", "group 2"]) + for group_id in info["groups"]: + properties = self.source.get_base_group_properties(info=info, group_id=group_id) + self.assertEqual(properties, {"name": group_id}) + + def test_user_property_mappings(self): + self.source.user_property_mappings.add( + OAuthSourcePropertyMapping.objects.create( + name="test", + expression="return {'attributes': {'department': info.get('department')}}", + ) + ) + request = get_request("/", user=AnonymousUser()) + flow_manager = OAuthSourceFlowManager(self.source, request, IDENTIFIER, {"info": INFO}, {}) + self.assertEqual( + flow_manager.user_properties, + { + "attributes": { + "department": "Engineering", + }, + "email": "alice@example.com", + "name": "Alice Adams", + "username": "foo", + "path": self.source.get_user_path(), + }, + ) + + def test_grup_property_mappings(self): + info = deepcopy(INFO) + info["groups"] = ["group 1", "group 2"] + self.source.group_property_mappings.add( + OAuthSourcePropertyMapping.objects.create( + name="test", + expression="return {'attributes': {'id': group_id}}", + ) + ) + request = get_request("/", user=AnonymousUser()) + flow_manager = OAuthSourceFlowManager(self.source, request, IDENTIFIER, {"info": info}, {}) + self.assertEqual( + flow_manager.groups_properties, + { + "group 1": { + "name": "group 1", + "attributes": { + "id": "group 1", + }, + }, + "group 2": { + "name": "group 2", + "attributes": { + "id": "group 2", + }, + }, + }, + ) diff --git a/authentik/sources/oauth/tests/test_type_azure_ad.py b/authentik/sources/oauth/tests/test_type_azure_ad.py index e34892fad3..5762bafbfe 100644 --- a/authentik/sources/oauth/tests/test_type_azure_ad.py +++ b/authentik/sources/oauth/tests/test_type_azure_ad.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback +from authentik.sources.oauth.types.azure_ad import AzureADOAuthCallback, AzureADType # https://docs.microsoft.com/en-us/graph/api/user-get?view=graph-rest-1.0&tabs=http#response-2 AAD_USER = { @@ -41,7 +41,7 @@ class TestTypeAzureAD(TestCase): def test_enroll_context(self): """Test azure_ad Enrollment context""" - ak_context = AzureADOAuthCallback().get_user_enroll_context(AAD_USER) + ak_context = AzureADType().get_base_user_properties(source=self.source, info=AAD_USER) self.assertEqual(ak_context["username"], AAD_USER["userPrincipalName"]) self.assertEqual(ak_context["email"], AAD_USER["mail"]) self.assertEqual(ak_context["name"], AAD_USER["displayName"]) diff --git a/authentik/sources/oauth/tests/test_type_discord.py b/authentik/sources/oauth/tests/test_type_discord.py index e1c996a2d1..c86019b56e 100644 --- a/authentik/sources/oauth/tests/test_type_discord.py +++ b/authentik/sources/oauth/tests/test_type_discord.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.discord import DiscordOAuth2Callback +from authentik.sources.oauth.types.discord import DiscordType # https://discord.com/developers/docs/resources/user#user-object DISCORD_USER = { @@ -34,7 +34,7 @@ class TestTypeDiscord(TestCase): def test_enroll_context(self): """Test discord Enrollment context""" - ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER) + ak_context = DiscordType().get_base_user_properties(source=self.source, info=DISCORD_USER) self.assertEqual(ak_context["username"], DISCORD_USER["username"]) self.assertEqual(ak_context["email"], DISCORD_USER["email"]) self.assertEqual(ak_context["name"], DISCORD_USER["username"]) diff --git a/authentik/sources/oauth/tests/test_type_github.py b/authentik/sources/oauth/tests/test_type_github.py index a24cdaaa75..2e7a42231d 100644 --- a/authentik/sources/oauth/tests/test_type_github.py +++ b/authentik/sources/oauth/tests/test_type_github.py @@ -7,7 +7,10 @@ from requests_mock import Mocker from authentik.lib.generators import generate_id from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.github import GitHubOAuth2Callback +from authentik.sources.oauth.types.github import ( + GitHubOAuth2Callback, + GitHubType, +) # https://developer.github.com/v3/users/#get-the-authenticated-user GITHUB_USER = { @@ -66,7 +69,9 @@ class TestTypeGitHub(TestCase): def test_enroll_context(self): """Test GitHub Enrollment context""" - ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER) + ak_context = GitHubType().get_base_user_properties( + source=self.source, info=GITHUB_USER, client=None, token={} + ) self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["email"], GITHUB_USER["email"]) self.assertEqual(ak_context["name"], GITHUB_USER["name"]) @@ -86,14 +91,18 @@ class TestTypeGitHub(TestCase): } ], ) - ak_context = GitHubOAuth2Callback( + token = { + "access_token": generate_id(), + "token_type": generate_id(), + } + callback = GitHubOAuth2Callback( source=self.source, request=self.factory.get("/"), - token={ - "access_token": generate_id(), - "token_type": generate_id(), - }, - ).get_user_enroll_context(user) + token=token, + ) + ak_context = GitHubType().get_base_user_properties( + source=self.source, info=user, client=callback.get_client(self.source), token=token + ) self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["email"], email) self.assertEqual(ak_context["name"], GITHUB_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_gitlab.py b/authentik/sources/oauth/tests/test_type_gitlab.py index 99bfa25bae..8d2a5336d4 100644 --- a/authentik/sources/oauth/tests/test_type_gitlab.py +++ b/authentik/sources/oauth/tests/test_type_gitlab.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.gitlab import GitLabOAuthCallback +from authentik.sources.oauth.types.gitlab import GitLabType GITLAB_USER = { "preferred_username": "dev_gitlab", @@ -24,7 +24,7 @@ class TestTypeGitLab(TestCase): def test_enroll_context(self): """Test GitLab Enrollment context""" - ak_context = GitLabOAuthCallback().get_user_enroll_context(GITLAB_USER) + ak_context = GitLabType().get_base_user_properties(source=self.source, info=GITLAB_USER) self.assertEqual(ak_context["username"], GITLAB_USER["preferred_username"]) self.assertEqual(ak_context["email"], GITLAB_USER["email"]) self.assertEqual(ak_context["name"], GITLAB_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_google.py b/authentik/sources/oauth/tests/test_type_google.py index 0b6d3888e2..3ecd8bbb10 100644 --- a/authentik/sources/oauth/tests/test_type_google.py +++ b/authentik/sources/oauth/tests/test_type_google.py @@ -6,7 +6,10 @@ from django.test.client import RequestFactory from authentik.lib.tests.utils import dummy_get_response from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.google import GoogleOAuth2Callback, GoogleOAuthRedirect +from authentik.sources.oauth.types.google import ( + GoogleOAuthRedirect, + GoogleType, +) # https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en GOOGLE_USER = { @@ -37,7 +40,7 @@ class TestTypeGoogle(TestCase): def test_enroll_context(self): """Test Google Enrollment context""" - ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER) + ak_context = GoogleType().get_base_user_properties(source=self.source, info=GOOGLE_USER) self.assertEqual(ak_context["email"], GOOGLE_USER["email"]) self.assertEqual(ak_context["name"], GOOGLE_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_mailcow.py b/authentik/sources/oauth/tests/test_type_mailcow.py index 8a7a4b30db..9d91d58b99 100644 --- a/authentik/sources/oauth/tests/test_type_mailcow.py +++ b/authentik/sources/oauth/tests/test_type_mailcow.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.mailcow import MailcowOAuth2Callback +from authentik.sources.oauth.types.mailcow import MailcowType # https://community.mailcow.email/d/13-mailcow-oauth-json-format/2 MAILCOW_USER = { @@ -34,6 +34,6 @@ class TestTypeMailcow(TestCase): def test_enroll_context(self): """Test mailcow Enrollment context""" - ak_context = MailcowOAuth2Callback().get_user_enroll_context(MAILCOW_USER) + ak_context = MailcowType().get_base_user_properties(source=self.source, info=MAILCOW_USER) self.assertEqual(ak_context["email"], MAILCOW_USER["email"]) self.assertEqual(ak_context["name"], MAILCOW_USER["full_name"]) diff --git a/authentik/sources/oauth/tests/test_type_openid.py b/authentik/sources/oauth/tests/test_type_openid.py index f8c7805489..3b96d5ee04 100644 --- a/authentik/sources/oauth/tests/test_type_openid.py +++ b/authentik/sources/oauth/tests/test_type_openid.py @@ -5,7 +5,7 @@ from requests_mock import Mocker from authentik.lib.generators import generate_id from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback +from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback, OpenIDConnectType # https://connect2id.com/products/server/docs/api/userinfo OPENID_USER = { @@ -34,7 +34,9 @@ class TestTypeOpenID(TestCase): def test_enroll_context(self): """Test OpenID Enrollment context""" - ak_context = OpenIDConnectOAuth2Callback().get_user_enroll_context(OPENID_USER) + ak_context = OpenIDConnectType().get_base_user_properties( + source=self.source, info=OPENID_USER + ) self.assertEqual(ak_context["username"], OPENID_USER["nickname"]) self.assertEqual(ak_context["email"], OPENID_USER["email"]) self.assertEqual(ak_context["name"], OPENID_USER["name"]) diff --git a/authentik/sources/oauth/tests/test_type_patreon.py b/authentik/sources/oauth/tests/test_type_patreon.py index 680df0724c..3dc1a18187 100644 --- a/authentik/sources/oauth/tests/test_type_patreon.py +++ b/authentik/sources/oauth/tests/test_type_patreon.py @@ -1,9 +1,9 @@ """Patreon Type tests""" -from django.test import RequestFactory, TestCase +from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.patreon import PatreonOAuthCallback +from authentik.sources.oauth.types.patreon import PatreonType PATREON_USER = { "data": { @@ -58,11 +58,10 @@ class TestTypePatreon(TestCase): slug="test", provider_type="Patreon", ) - self.factory = RequestFactory() def test_enroll_context(self): """Test Patreon Enrollment context""" - ak_context = PatreonOAuthCallback().get_user_enroll_context(PATREON_USER) + ak_context = PatreonType().get_base_user_properties(source=self.source, info=PATREON_USER) self.assertEqual(ak_context["username"], PATREON_USER["data"]["attributes"]["vanity"]) self.assertEqual(ak_context["email"], PATREON_USER["data"]["attributes"]["email"]) self.assertEqual(ak_context["name"], PATREON_USER["data"]["attributes"]["full_name"]) diff --git a/authentik/sources/oauth/tests/test_type_twitch.py b/authentik/sources/oauth/tests/test_type_twitch.py index e9fbe4acd7..d49a2c53ed 100644 --- a/authentik/sources/oauth/tests/test_type_twitch.py +++ b/authentik/sources/oauth/tests/test_type_twitch.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.twitch import TwitchOAuth2Callback +from authentik.sources.oauth.types.twitch import TwitchType # https://dev.twitch.tv/docs/authentication/getting-tokens-oidc/#getting-claims-information-from-an-access-token TWITCH_USER = { @@ -32,7 +32,7 @@ class TestTypeTwitch(TestCase): def test_enroll_context(self): """Test twitch Enrollment context""" - ak_context = TwitchOAuth2Callback().get_user_enroll_context(TWITCH_USER) + ak_context = TwitchType().get_base_user_properties(source=self.source, info=TWITCH_USER) self.assertEqual(ak_context["username"], TWITCH_USER["preferred_username"]) self.assertEqual(ak_context["email"], TWITCH_USER["email"]) self.assertEqual(ak_context["name"], TWITCH_USER["preferred_username"]) diff --git a/authentik/sources/oauth/tests/test_type_twitter.py b/authentik/sources/oauth/tests/test_type_twitter.py index 8ff9f2ffd5..4450aa7b2f 100644 --- a/authentik/sources/oauth/tests/test_type_twitter.py +++ b/authentik/sources/oauth/tests/test_type_twitter.py @@ -3,7 +3,7 @@ from django.test import TestCase from authentik.sources.oauth.models import OAuthSource -from authentik.sources.oauth.types.twitter import TwitterOAuthCallback +from authentik.sources.oauth.types.twitter import TwitterType # https://developer.twitter.com/en/docs/twitter-api/users/lookup/api-reference/get-users-me TWITTER_USER = {"data": {"id": "2244994945", "name": "TwitterDev", "username": "Twitter Dev"}} @@ -24,7 +24,7 @@ class TestTypeGitHub(TestCase): def test_enroll_context(self): """Test Twitter Enrollment context""" - ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER) + ak_context = TwitterType().get_base_user_properties(source=self.source, info=TWITTER_USER) self.assertEqual(ak_context["username"], TWITTER_USER["data"]["username"]) self.assertEqual(ak_context["email"], None) self.assertEqual(ak_context["name"], TWITTER_USER["data"]["name"]) diff --git a/authentik/sources/oauth/types/apple.py b/authentik/sources/oauth/types/apple.py index 3d272b6042..1e3b8a2099 100644 --- a/authentik/sources/oauth/types/apple.py +++ b/authentik/sources/oauth/types/apple.py @@ -90,15 +90,6 @@ class AppleOAuth2Callback(OAuthCallback): def get_user_id(self, info: dict[str, Any]) -> str | None: return info["sub"] - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class AppleType(SourceType): @@ -132,3 +123,9 @@ class AppleType(SourceType): "state": args["state"], } ) + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "email": info.get("email"), + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index 341831138f..7d7f4e1592 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -31,17 +31,6 @@ class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): # fallback to OpenID logic in case the profile URL was changed return info.get("id", super().get_user_id(info)) - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - mail = info.get("mail", None) or info.get("otherMails", [None])[0] - return { - "username": info.get("userPrincipalName"), - "email": mail, - "name": info.get("displayName"), - } - @registry.register() class AzureADType(SourceType): @@ -61,3 +50,11 @@ class AzureADType(SourceType): "https://login.microsoftonline.com/common/.well-known/openid-configuration" ) oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + mail = info.get("mail", None) or info.get("otherMails", [None])[0] + return { + "username": info.get("userPrincipalName"), + "email": mail, + "name": info.get("displayName"), + } diff --git a/authentik/sources/oauth/types/discord.py b/authentik/sources/oauth/types/discord.py index a67c07bf13..815feb710d 100644 --- a/authentik/sources/oauth/types/discord.py +++ b/authentik/sources/oauth/types/discord.py @@ -20,16 +20,6 @@ class DiscordOAuthRedirect(OAuthRedirect): class DiscordOAuth2Callback(OAuthCallback): """Discord OAuth2 Callback""" - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("username"), - "email": info.get("email", None), - "name": info.get("username"), - } - @registry.register() class DiscordType(SourceType): @@ -43,3 +33,10 @@ class DiscordType(SourceType): authorization_url = "https://discord.com/api/oauth2/authorize" access_token_url = "https://discord.com/api/oauth2/token" # nosec profile_url = "https://discord.com/api/users/@me" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("username"), + "email": info.get("email", None), + "name": info.get("username"), + } diff --git a/authentik/sources/oauth/types/facebook.py b/authentik/sources/oauth/types/facebook.py index f871af64bd..697e1b1c4e 100644 --- a/authentik/sources/oauth/types/facebook.py +++ b/authentik/sources/oauth/types/facebook.py @@ -19,16 +19,6 @@ class FacebookOAuthRedirect(OAuthRedirect): class FacebookOAuth2Callback(OAuthCallback): """Facebook OAuth2 Callback""" - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("name"), - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class FacebookType(SourceType): @@ -42,3 +32,10 @@ class FacebookType(SourceType): authorization_url = "https://www.facebook.com/v7.0/dialog/oauth" access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec profile_url = "https://graph.facebook.com/v7.0/me?fields=id,name,email" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("name"), + "email": info.get("email"), + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/github.py b/authentik/sources/oauth/types/github.py index ce88ae78d0..a38e4ebb61 100644 --- a/authentik/sources/oauth/types/github.py +++ b/authentik/sources/oauth/types/github.py @@ -5,6 +5,7 @@ from typing import Any from requests.exceptions import RequestException from authentik.sources.oauth.clients.oauth2 import OAuth2Client +from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.redirect import OAuthRedirect @@ -42,26 +43,6 @@ class GitHubOAuth2Callback(OAuthCallback): client_class = GitHubOAuth2Client - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - chosen_email = info.get("email") - if not chosen_email: - # The GitHub Userprofile API only returns an email address if the profile - # has a public email address set (despite us asking for user:email, this behaviour - # doesn't change.). So we fetch all the user's email addresses - client: GitHubOAuth2Client = self.get_client(self.source) - emails = client.get_github_emails(self.token) - for email in emails: - if email.get("primary", False): - chosen_email = email.get("email", None) - return { - "username": info.get("login"), - "email": chosen_email, - "name": info.get("name"), - } - @registry.register() class GitHubType(SourceType): @@ -81,3 +62,26 @@ class GitHubType(SourceType): "https://token.actions.githubusercontent.com/.well-known/openid-configuration" ) oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks" + + def get_base_user_properties( + self, + source: OAuthSource, + client: GitHubOAuth2Client, + token: dict[str, str], + info: dict[str, Any], + **kwargs, + ) -> dict[str, Any]: + chosen_email = info.get("email") + if not chosen_email: + # The GitHub Userprofile API only returns an email address if the profile + # has a public email address set (despite us asking for user:email, this behaviour + # doesn't change.). So we fetch all the user's email addresses + emails = client.get_github_emails(token) + for email in emails: + if email.get("primary", False): + chosen_email = email.get("email", None) + return { + "username": info.get("login"), + "email": chosen_email, + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/gitlab.py b/authentik/sources/oauth/types/gitlab.py index 3d90ea7e5d..bda3f5d92c 100644 --- a/authentik/sources/oauth/types/gitlab.py +++ b/authentik/sources/oauth/types/gitlab.py @@ -25,16 +25,6 @@ class GitLabOAuthRedirect(OAuthRedirect): class GitLabOAuthCallback(OAuthCallback): """GitLab OAuth2 Callback""" - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("preferred_username"), - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class GitLabType(SourceType): @@ -52,3 +42,10 @@ class GitLabType(SourceType): profile_url = "https://gitlab.com/oauth/userinfo" oidc_well_known_url = "https://gitlab.com/.well-known/openid-configuration" oidc_jwks_url = "https://gitlab.com/oauth/discovery/keys" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("preferred_username"), + "email": info.get("email"), + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/google.py b/authentik/sources/oauth/types/google.py index add0eab845..e3ec385848 100644 --- a/authentik/sources/oauth/types/google.py +++ b/authentik/sources/oauth/types/google.py @@ -19,15 +19,6 @@ class GoogleOAuthRedirect(OAuthRedirect): class GoogleOAuth2Callback(OAuthCallback): """Google OAuth2 Callback""" - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class GoogleType(SourceType): @@ -43,3 +34,9 @@ class GoogleType(SourceType): profile_url = "https://www.googleapis.com/oauth2/v1/userinfo" oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration" oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "email": info.get("email"), + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/mailcow.py b/authentik/sources/oauth/types/mailcow.py index 37895e114a..b9ef8991cc 100644 --- a/authentik/sources/oauth/types/mailcow.py +++ b/authentik/sources/oauth/types/mailcow.py @@ -47,16 +47,6 @@ class MailcowOAuth2Callback(OAuthCallback): client_class = MailcowOAuth2Client - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("full_name"), - "email": info.get("email"), - "name": info.get("full_name"), - } - @registry.register() class MailcowType(SourceType): @@ -68,3 +58,10 @@ class MailcowType(SourceType): name = "mailcow" urls_customizable = True + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("full_name"), + "email": info.get("email"), + "name": info.get("full_name"), + } diff --git a/authentik/sources/oauth/types/oidc.py b/authentik/sources/oauth/types/oidc.py index 017e04b5e2..5866b868f1 100644 --- a/authentik/sources/oauth/types/oidc.py +++ b/authentik/sources/oauth/types/oidc.py @@ -26,16 +26,6 @@ class OpenIDConnectOAuth2Callback(OAuthCallback): def get_user_id(self, info: dict[str, str]) -> str: return info.get("sub", None) - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("nickname", info.get("preferred_username")), - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class OpenIDConnectType(SourceType): @@ -47,3 +37,11 @@ class OpenIDConnectType(SourceType): name = "openidconnect" urls_customizable = True + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("nickname", info.get("preferred_username")), + "email": info.get("email"), + "name": info.get("name"), + "groups": info.get("groups", []), + } diff --git a/authentik/sources/oauth/types/okta.py b/authentik/sources/oauth/types/okta.py index 1698cf2d1f..2e43989c5a 100644 --- a/authentik/sources/oauth/types/okta.py +++ b/authentik/sources/oauth/types/okta.py @@ -26,16 +26,6 @@ class OktaOAuth2Callback(OpenIDConnectOAuth2Callback): # see https://github.com/goauthentik/authentik/issues/1910 client_class = UserprofileHeaderAuthClient - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("nickname"), - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class OktaType(SourceType): @@ -47,3 +37,11 @@ class OktaType(SourceType): name = "okta" urls_customizable = True + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("nickname"), + "email": info.get("email"), + "name": info.get("name"), + "groups": info.get("groups", []), + } diff --git a/authentik/sources/oauth/types/patreon.py b/authentik/sources/oauth/types/patreon.py index 07bf307f09..5d0fbd713c 100644 --- a/authentik/sources/oauth/types/patreon.py +++ b/authentik/sources/oauth/types/patreon.py @@ -27,16 +27,6 @@ class PatreonOAuthCallback(OAuthCallback): def get_user_id(self, info: dict[str, str]) -> str: return info.get("data", {}).get("id") - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("data", {}).get("attributes", {}).get("vanity"), - "email": info.get("data", {}).get("attributes", {}).get("email"), - "name": info.get("data", {}).get("attributes", {}).get("full_name"), - } - @registry.register() class PatreonType(SourceType): @@ -50,3 +40,10 @@ class PatreonType(SourceType): authorization_url = "https://www.patreon.com/oauth2/authorize" access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec profile_url = "https://www.patreon.com/api/oauth2/api/current_user" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("data", {}).get("attributes", {}).get("vanity"), + "email": info.get("data", {}).get("attributes", {}).get("email"), + "name": info.get("data", {}).get("attributes", {}).get("full_name"), + } diff --git a/authentik/sources/oauth/types/reddit.py b/authentik/sources/oauth/types/reddit.py index 1b901bb37a..c7d1e4a7dc 100644 --- a/authentik/sources/oauth/types/reddit.py +++ b/authentik/sources/oauth/types/reddit.py @@ -34,17 +34,6 @@ class RedditOAuth2Callback(OAuthCallback): client_class = RedditOAuth2Client - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("name"), - "email": None, - "name": info.get("name"), - "password": None, - } - @registry.register() class RedditType(SourceType): @@ -58,3 +47,10 @@ class RedditType(SourceType): authorization_url = "https://www.reddit.com/api/v1/authorize" access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec profile_url = "https://oauth.reddit.com/api/v1/me" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("name"), + "email": None, + "name": info.get("name"), + } diff --git a/authentik/sources/oauth/types/registry.py b/authentik/sources/oauth/types/registry.py index 7ff1c85dbe..6a15441e60 100644 --- a/authentik/sources/oauth/types/registry.py +++ b/authentik/sources/oauth/types/registry.py @@ -2,6 +2,7 @@ from collections.abc import Callable from enum import Enum +from typing import Any from django.http.request import HttpRequest from django.templatetags.static import static @@ -55,6 +56,20 @@ class SourceType: } ) + def get_base_user_properties( + self, source: OAuthSource, info: dict[str, Any], **kwargs + ) -> dict[str, Any | dict[str, Any]]: + """Get base user properties for enrollment/update""" + return info + + def get_base_group_properties( + self, source: OAuthSource, group_id: str, **kwargs + ) -> dict[str, Any | dict[str, Any]]: + """Get base group properties for creation/update""" + return { + "name": group_id, + } + class SourceTypeRegistry: """Registry to hold all Source types.""" diff --git a/authentik/sources/oauth/types/twitch.py b/authentik/sources/oauth/types/twitch.py index 777d457867..a4b8fb98de 100644 --- a/authentik/sources/oauth/types/twitch.py +++ b/authentik/sources/oauth/types/twitch.py @@ -33,16 +33,6 @@ class TwitchOAuth2Callback(OpenIDConnectOAuth2Callback): client_class = TwitchClient - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("preferred_username"), - "email": info.get("email"), - "name": info.get("preferred_username"), - } - @registry.register() class TwitchType(SourceType): @@ -56,3 +46,10 @@ class TwitchType(SourceType): authorization_url = "https://id.twitch.tv/oauth2/authorize" access_token_url = "https://id.twitch.tv/oauth2/token" # nosec profile_url = "https://id.twitch.tv/oauth2/userinfo" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("preferred_username"), + "email": info.get("email"), + "name": info.get("preferred_username"), + } diff --git a/authentik/sources/oauth/types/twitter.py b/authentik/sources/oauth/types/twitter.py index 8b1aa66124..8e17539f32 100644 --- a/authentik/sources/oauth/types/twitter.py +++ b/authentik/sources/oauth/types/twitter.py @@ -49,17 +49,6 @@ class TwitterOAuthCallback(OAuthCallback): def get_user_id(self, info: dict[str, str]) -> str: return info.get("data", {}).get("id", "") - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - data = info.get("data", {}) - return { - "username": data.get("username"), - "email": None, - "name": data.get("name"), - } - @registry.register() class TwitterType(SourceType): @@ -73,3 +62,11 @@ class TwitterType(SourceType): authorization_url = "https://twitter.com/i/oauth2/authorize" access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec profile_url = "https://api.twitter.com/2/users/me" + + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + data = info.get("data", {}) + return { + "username": data.get("username"), + "email": None, + "name": data.get("name"), + } diff --git a/authentik/sources/oauth/urls.py b/authentik/sources/oauth/urls.py index 5914f7d017..be256892c6 100644 --- a/authentik/sources/oauth/urls.py +++ b/authentik/sources/oauth/urls.py @@ -2,8 +2,12 @@ from django.urls import path +from authentik.sources.oauth.api.property_mappings import OAuthSourcePropertyMappingViewSet from authentik.sources.oauth.api.source import OAuthSourceViewSet -from authentik.sources.oauth.api.source_connection import UserOAuthSourceConnectionViewSet +from authentik.sources.oauth.api.source_connection import ( + GroupOAuthSourceConnectionViewSet, + UserOAuthSourceConnectionViewSet, +) from authentik.sources.oauth.types.registry import RequestKind from authentik.sources.oauth.views.dispatcher import DispatcherView @@ -21,6 +25,8 @@ urlpatterns = [ ] api_urlpatterns = [ + ("propertymappings/source/oauth", OAuthSourcePropertyMappingViewSet), ("sources/user_connections/oauth", UserOAuthSourceConnectionViewSet), + ("sources/group_connections/oauth", GroupOAuthSourceConnectionViewSet), ("sources/oauth", OAuthSourceViewSet), ] diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index 79dae2e2ca..6126671aa8 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -13,7 +13,11 @@ from structlog.stdlib import get_logger from authentik.core.sources.flow_manager import SourceFlowManager from authentik.events.models import Event, EventAction -from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.models import ( + GroupOAuthSourceConnection, + OAuthSource, + UserOAuthSourceConnection, +) from authentik.sources.oauth.views.base import OAuthClientMixin LOGGER = get_logger() @@ -57,15 +61,19 @@ class OAuthCallback(OAuthClientMixin, View): identifier = self.get_user_id(info=raw_info) if identifier is None: return self.handle_login_failure("Could not determine id.") - # Get or create access record - enroll_info = self.get_user_enroll_context(raw_info) sfm = OAuthSourceFlowManager( source=self.source, request=self.request, identifier=identifier, - enroll_info=enroll_info, + user_info={ + "info": raw_info, + "client": client, + "token": self.token, + }, + policy_context={ + "oauth_userinfo": raw_info, + }, ) - sfm.policy_context = {"oauth_userinfo": raw_info} return sfm.get_flow( raw_info=raw_info, access_token=self.token.get("access_token"), @@ -79,13 +87,6 @@ class OAuthCallback(OAuthClientMixin, View): "Return url to redirect on login failure." return settings.LOGIN_URL - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - """Create a dict of User data""" - raise NotImplementedError() - def get_user_id(self, info: dict[str, Any]) -> str | None: """Return unique identifier from the profile info.""" if "id" in info: @@ -111,9 +112,10 @@ class OAuthCallback(OAuthClientMixin, View): class OAuthSourceFlowManager(SourceFlowManager): """Flow manager for oauth sources""" - connection_type = UserOAuthSourceConnection + user_connection_type = UserOAuthSourceConnection + group_connection_type = GroupOAuthSourceConnection - def update_connection( + def update_user_connection( self, connection: UserOAuthSourceConnection, access_token: str | None = None, diff --git a/authentik/sources/plex/api/source.py b/authentik/sources/plex/api/source.py index 6d30b6c1e5..4560aea355 100644 --- a/authentik/sources/plex/api/source.py +++ b/authentik/sources/plex/api/source.py @@ -109,7 +109,8 @@ class PlexSourceViewSet(UsedByMixin, ModelViewSet): source=source, request=request, identifier=str(identifier), - enroll_info=user_info, + user_info=user_info, + policy_context={}, ) return to_stage_response(request, sfm.get_flow(plex_token=plex_token)) LOGGER.warning( diff --git a/authentik/sources/plex/plex.py b/authentik/sources/plex/plex.py index caf245888f..60bc5e1c2b 100644 --- a/authentik/sources/plex/plex.py +++ b/authentik/sources/plex/plex.py @@ -113,9 +113,11 @@ class PlexAuth: class PlexSourceFlowManager(SourceFlowManager): """Flow manager for plex sources""" - connection_type = PlexSourceConnection + user_connection_type = PlexSourceConnection - def update_connection(self, connection: PlexSourceConnection, **kwargs) -> PlexSourceConnection: + def update_user_connection( + self, connection: PlexSourceConnection, **kwargs + ) -> PlexSourceConnection: """Set the access_token on the connection""" connection.plex_token = kwargs.get("plex_token") return connection diff --git a/authentik/sources/saml/api/property_mappings.py b/authentik/sources/saml/api/property_mappings.py new file mode 100644 index 0000000000..8153505283 --- /dev/null +++ b/authentik/sources/saml/api/property_mappings.py @@ -0,0 +1,31 @@ +"""SAML source property mappings API""" + +from rest_framework.viewsets import ModelViewSet + +from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer +from authentik.core.api.used_by import UsedByMixin +from authentik.sources.saml.models import SAMLSourcePropertyMapping + + +class SAMLSourcePropertyMappingSerializer(PropertyMappingSerializer): + """SAMLSourcePropertyMapping Serializer""" + + class Meta(PropertyMappingSerializer.Meta): + model = SAMLSourcePropertyMapping + + +class SAMLSourcePropertyMappingFilter(PropertyMappingFilterSet): + """Filter for SAMLSourcePropertyMapping""" + + class Meta(PropertyMappingFilterSet.Meta): + model = SAMLSourcePropertyMapping + + +class SAMLSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet): + """SAMLSourcePropertyMapping Viewset""" + + queryset = SAMLSourcePropertyMapping.objects.all() + serializer_class = SAMLSourcePropertyMappingSerializer + filterset_class = SAMLSourcePropertyMappingFilter + search_fields = ["name"] + ordering = ["name"] diff --git a/authentik/sources/saml/api/source.py b/authentik/sources/saml/api/source.py index a3f0e9bd41..0070797576 100644 --- a/authentik/sources/saml/api/source.py +++ b/authentik/sources/saml/api/source.py @@ -20,6 +20,7 @@ class SAMLSourceSerializer(SourceSerializer): class Meta: model = SAMLSource fields = SourceSerializer.Meta.fields + [ + "group_matching_mode", "pre_authentication_flow", "issuer", "sso_url", diff --git a/authentik/sources/saml/api/source_connection.py b/authentik/sources/saml/api/source_connection.py index b5e276140d..7b97a12035 100644 --- a/authentik/sources/saml/api/source_connection.py +++ b/authentik/sources/saml/api/source_connection.py @@ -3,10 +3,12 @@ from rest_framework.viewsets import ModelViewSet from authentik.core.api.sources import ( + GroupSourceConnectionSerializer, + GroupSourceConnectionViewSet, UserSourceConnectionSerializer, UserSourceConnectionViewSet, ) -from authentik.sources.saml.models import UserSAMLSourceConnection +from authentik.sources.saml.models import GroupSAMLSourceConnection, UserSAMLSourceConnection class UserSAMLSourceConnectionSerializer(UserSourceConnectionSerializer): @@ -22,3 +24,17 @@ class UserSAMLSourceConnectionViewSet(UserSourceConnectionViewSet, ModelViewSet) queryset = UserSAMLSourceConnection.objects.all() serializer_class = UserSAMLSourceConnectionSerializer + + +class GroupSAMLSourceConnectionSerializer(GroupSourceConnectionSerializer): + """OAuth Group-Source connection Serializer""" + + class Meta(GroupSourceConnectionSerializer.Meta): + model = GroupSAMLSourceConnection + + +class GroupSAMLSourceConnectionViewSet(GroupSourceConnectionViewSet): + """Group-source connection Viewset""" + + queryset = GroupSAMLSourceConnection.objects.all() + serializer_class = GroupSAMLSourceConnectionSerializer diff --git a/authentik/sources/saml/migrations/0015_groupsamlsourceconnection_samlsourcepropertymapping.py b/authentik/sources/saml/migrations/0015_groupsamlsourceconnection_samlsourcepropertymapping.py new file mode 100644 index 0000000000..21346442c4 --- /dev/null +++ b/authentik/sources/saml/migrations/0015_groupsamlsourceconnection_samlsourcepropertymapping.py @@ -0,0 +1,57 @@ +# Generated by Django 5.0.7 on 2024-08-01 18:52 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_core", "0039_source_group_matching_mode_alter_group_name_and_more"), + ("authentik_sources_saml", "0014_alter_samlsource_digest_algorithm_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="GroupSAMLSourceConnection", + fields=[ + ( + "groupsourceconnection_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="authentik_core.groupsourceconnection", + ), + ), + ], + options={ + "verbose_name": "Group SAML Source Connection", + "verbose_name_plural": "Group SAML Source Connections", + }, + bases=("authentik_core.groupsourceconnection",), + ), + migrations.CreateModel( + name="SAMLSourcePropertyMapping", + fields=[ + ( + "propertymapping_ptr", + models.OneToOneField( + auto_created=True, + on_delete=django.db.models.deletion.CASCADE, + parent_link=True, + primary_key=True, + serialize=False, + to="authentik_core.propertymapping", + ), + ), + ], + options={ + "verbose_name": "SAML Source Property Mapping", + "verbose_name_plural": "SAML Source Property Mappings", + }, + bases=("authentik_core.propertymapping",), + ), + ] diff --git a/authentik/sources/saml/models.py b/authentik/sources/saml/models.py index 94179a080f..99c1c2e71f 100644 --- a/authentik/sources/saml/models.py +++ b/authentik/sources/saml/models.py @@ -1,5 +1,7 @@ """saml sp models""" +from typing import Any + from django.db import models from django.http import HttpRequest from django.templatetags.static import static @@ -7,11 +9,17 @@ from django.urls import reverse from django.utils.translation import gettext_lazy as _ from rest_framework.serializers import Serializer -from authentik.core.models import Source, UserSourceConnection +from authentik.core.models import ( + GroupSourceConnection, + PropertyMapping, + Source, + UserSourceConnection, +) from authentik.core.types import UILoginButton, UserSettingSerializer from authentik.crypto.models import CertificateKeyPair from authentik.flows.challenge import RedirectChallenge from authentik.flows.models import Flow +from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.utils.time import timedelta_string_validator from authentik.sources.saml.processors.constants import ( DSA_SHA1, @@ -19,10 +27,12 @@ from authentik.sources.saml.processors.constants import ( ECDSA_SHA256, ECDSA_SHA384, ECDSA_SHA512, + NS_SAML_ASSERTION, RSA_SHA1, RSA_SHA256, RSA_SHA384, RSA_SHA512, + SAML_ATTRIBUTES_GROUP, SAML_BINDING_POST, SAML_BINDING_REDIRECT, SAML_NAME_ID_FORMAT_EMAIL, @@ -182,11 +192,39 @@ class SAMLSource(Source): return SAMLSourceSerializer @property - def icon_url(self) -> str: - icon = super().icon_url - if not icon: - return static("authentik/sources/saml.png") - return icon + def property_mapping_type(self) -> type[PropertyMapping]: + return SAMLSourcePropertyMapping + + def get_base_user_properties(self, root: Any, name_id: Any, **kwargs): + attributes = {} + assertion = root.find(f"{{{NS_SAML_ASSERTION}}}Assertion") + if assertion is None: + raise ValueError("Assertion element not found") + attribute_statement = assertion.find(f"{{{NS_SAML_ASSERTION}}}AttributeStatement") + if attribute_statement is None: + raise ValueError("Attribute statement element not found") + # Get all attributes and their values into a dict + for attribute in attribute_statement.iterchildren(): + key = attribute.attrib["Name"] + attributes.setdefault(key, []) + for value in attribute.iterchildren(): + attributes[key].append(value.text) + if SAML_ATTRIBUTES_GROUP in attributes: + attributes["groups"] = attributes[SAML_ATTRIBUTES_GROUP] + del attributes[SAML_ATTRIBUTES_GROUP] + # Flatten all lists in the dict + for key, value in attributes.items(): + if key == "groups": + continue + attributes[key] = BaseEvaluator.expr_flatten(value) + attributes["username"] = name_id.text + + return attributes + + def get_base_group_properties(self, group_id: str, **kwargs): + return { + "name": group_id, + } def get_issuer(self, request: HttpRequest) -> str: """Get Source's Issuer, falling back to our Metadata URL if none is set""" @@ -200,6 +238,13 @@ class SAMLSource(Source): reverse(f"authentik_sources_saml:{view}", kwargs={"source_slug": self.slug}) ) + @property + def icon_url(self) -> str: + icon = super().icon_url + if not icon: + return static("authentik/sources/saml.png") + return icon + def ui_login_button(self, request: HttpRequest) -> UILoginButton: return UILoginButton( challenge=RedirectChallenge( @@ -235,6 +280,24 @@ class SAMLSource(Source): verbose_name_plural = _("SAML Sources") +class SAMLSourcePropertyMapping(PropertyMapping): + """Map SAML properties to User or Group object attributes""" + + @property + def component(self) -> str: + return "ak-property-mapping-saml-source-form" + + @property + def serializer(self) -> type[Serializer]: + from authentik.sources.saml.api.property_mappings import SAMLSourcePropertyMappingSerializer + + return SAMLSourcePropertyMappingSerializer + + class Meta: + verbose_name = _("SAML Source Property Mapping") + verbose_name_plural = _("SAML Source Property Mappings") + + class UserSAMLSourceConnection(UserSourceConnection): """Connection to configured SAML Sources.""" @@ -249,3 +312,19 @@ class UserSAMLSourceConnection(UserSourceConnection): class Meta: verbose_name = _("User SAML Source Connection") verbose_name_plural = _("User SAML Source Connections") + + +class GroupSAMLSourceConnection(GroupSourceConnection): + """Group-source connection""" + + @property + def serializer(self) -> type[Serializer]: + from authentik.sources.saml.api.source_connection import ( + GroupSAMLSourceConnectionSerializer, + ) + + return GroupSAMLSourceConnectionSerializer + + class Meta: + verbose_name = _("Group SAML Source Connection") + verbose_name_plural = _("Group SAML Source Connections") diff --git a/authentik/sources/saml/processors/constants.py b/authentik/sources/saml/processors/constants.py index e0eed95ada..df126c6a44 100644 --- a/authentik/sources/saml/processors/constants.py +++ b/authentik/sources/saml/processors/constants.py @@ -21,6 +21,8 @@ SAML_NAME_ID_FORMAT_X509 = "urn:oasis:names:tc:SAML:2.0:nameid-format:X509Subjec SAML_NAME_ID_FORMAT_WINDOWS = "urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName" SAML_NAME_ID_FORMAT_TRANSIENT = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" +SAML_ATTRIBUTES_GROUP = "http://schemas.xmlsoap.org/claims/Group" + SAML_BINDING_POST = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" diff --git a/authentik/sources/saml/processors/response.py b/authentik/sources/saml/processors/response.py index b4bbcc890b..62d8dfd0ad 100644 --- a/authentik/sources/saml/processors/response.py +++ b/authentik/sources/saml/processors/response.py @@ -21,16 +21,18 @@ from authentik.core.models import ( User, ) from authentik.core.sources.flow_manager import SourceFlowManager -from authentik.lib.expression.evaluator import BaseEvaluator from authentik.lib.utils.time import timedelta_from_string -from authentik.policies.utils import delete_none_values from authentik.sources.saml.exceptions import ( InvalidSignature, MismatchedRequestID, MissingSAMLResponse, UnsupportedNameIDFormat, ) -from authentik.sources.saml.models import SAMLSource, UserSAMLSourceConnection +from authentik.sources.saml.models import ( + GroupSAMLSourceConnection, + SAMLSource, + UserSAMLSourceConnection, +) from authentik.sources.saml.processors.constants import ( NS_MAP, NS_SAML_ASSERTION, @@ -138,12 +140,12 @@ class ResponseProcessor: user has an attribute that refers to our Source for cleanup. The user is also deleted on logout and periodically.""" # Create a temporary User - name_id = self._get_name_id().text + name_id = self._get_name_id() expiry = mktime( (now() + timedelta_from_string(self._source.temporary_user_delete_after)).timetuple() ) user: User = User.objects.create( - username=name_id, + username=name_id.text, attributes={ USER_ATTRIBUTE_GENERATED: True, USER_ATTRIBUTE_SOURCES: [ @@ -154,15 +156,21 @@ class ResponseProcessor: }, path=self._source.get_user_path(), ) - LOGGER.debug("Created temporary user for NameID Transient", username=name_id) + LOGGER.debug("Created temporary user for NameID Transient", username=name_id.text) user.set_unusable_password() user.save() - UserSAMLSourceConnection.objects.create(source=self._source, user=user, identifier=name_id) + UserSAMLSourceConnection.objects.create( + source=self._source, user=user, identifier=name_id.text + ) return SAMLSourceFlowManager( - self._source, - self._http_request, - name_id, - delete_none_values(self.get_attributes()), + source=self._source, + request=self._http_request, + identifier=str(name_id.text), + user_info={ + "root": self._root, + "name_id": name_id, + }, + policy_context={}, ) def _get_name_id(self) -> "Element": @@ -200,27 +208,6 @@ class ResponseProcessor: f"Assertion contains NameID with unsupported format {_format}." ) - def get_attributes(self) -> dict[str, list[str] | str]: - """Get all attributes sent""" - attributes = {} - assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion") - if assertion is None: - raise ValueError("Assertion element not found") - attribute_statement = assertion.find(f"{{{NS_SAML_ASSERTION}}}AttributeStatement") - if attribute_statement is None: - raise ValueError("Attribute statement element not found") - # Get all attributes and their values into a dict - for attribute in attribute_statement.iterchildren(): - key = attribute.attrib["Name"] - attributes.setdefault(key, []) - for value in attribute.iterchildren(): - attributes[key].append(value.text) - # Flatten all lists in the dict - for key, value in attributes.items(): - attributes[key] = BaseEvaluator.expr_flatten(value) - attributes["username"] = self._get_name_id().text - return attributes - def prepare_flow_manager(self) -> SourceFlowManager: """Prepare flow plan depending on whether or not the user exists""" name_id = self._get_name_id() @@ -235,17 +222,22 @@ class ResponseProcessor: if name_id.attrib["Format"] == SAML_NAME_ID_FORMAT_TRANSIENT: return self._handle_name_id_transient() - flow_manager = SAMLSourceFlowManager( - self._source, - self._http_request, - name_id.text, - delete_none_values(self.get_attributes()), + return SAMLSourceFlowManager( + source=self._source, + request=self._http_request, + identifier=str(name_id.text), + user_info={ + "root": self._root, + "name_id": name_id, + }, + policy_context={ + "saml_response": etree.tostring(self._root), + }, ) - flow_manager.policy_context["saml_response"] = etree.tostring(self._root) - return flow_manager class SAMLSourceFlowManager(SourceFlowManager): """Source flow manager for SAML Sources""" - connection_type = UserSAMLSourceConnection + user_connection_type = UserSAMLSourceConnection + group_connection_type = GroupSAMLSourceConnection diff --git a/authentik/sources/saml/tests/fixtures/response_success_groups.xml b/authentik/sources/saml/tests/fixtures/response_success_groups.xml new file mode 100644 index 0000000000..b9c22ac536 --- /dev/null +++ b/authentik/sources/saml/tests/fixtures/response_success_groups.xml @@ -0,0 +1,46 @@ + + + https://accounts.google.com/o/saml2?idpid= + + + + + https://accounts.google.com/o/saml2?idpid= + + jens@goauthentik.io + + + + + + + https://accounts.google.com/o/saml2?idpid= + + + + + foo + + + bar + + + foo@bar.baz + + + group 1 + group 2 + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:unspecified + + + + diff --git a/authentik/sources/saml/tests/test_property_mappings.py b/authentik/sources/saml/tests/test_property_mappings.py new file mode 100644 index 0000000000..638159913e --- /dev/null +++ b/authentik/sources/saml/tests/test_property_mappings.py @@ -0,0 +1,135 @@ +"""SAML Source tests""" + +from base64 import b64encode + +from defusedxml.lxml import fromstring +from django.contrib.sessions.middleware import SessionMiddleware +from django.test import RequestFactory, TestCase + +from authentik.core.tests.utils import create_test_flow +from authentik.lib.generators import generate_id +from authentik.lib.tests.utils import dummy_get_response, load_fixture +from authentik.sources.saml.models import SAMLSource, SAMLSourcePropertyMapping +from authentik.sources.saml.processors.constants import NS_SAML_ASSERTION +from authentik.sources.saml.processors.response import ResponseProcessor + +ROOT = fromstring(load_fixture("fixtures/response_success.xml").encode()) +ROOT_GROUPS = fromstring(load_fixture("fixtures/response_success_groups.xml").encode()) +NAME_ID = ( + ROOT.find(f"{{{NS_SAML_ASSERTION}}}Assertion") + .find(f"{{{NS_SAML_ASSERTION}}}Subject") + .find(f"{{{NS_SAML_ASSERTION}}}NameID") +) + + +class TestPropertyMappings(TestCase): + """Test Property Mappings""" + + def setUp(self): + self.factory = RequestFactory() + self.source = SAMLSource.objects.create( + slug=generate_id(), + issuer="authentik", + allow_idp_initiated=True, + pre_authentication_flow=create_test_flow(), + ) + + def test_user_base_properties(self): + """Test user base properties""" + properties = self.source.get_base_user_properties(root=ROOT, name_id=NAME_ID) + self.assertEqual( + properties, + { + "email": "foo@bar.baz", + "name": "foo", + "sn": "bar", + "username": "jens@goauthentik.io", + }, + ) + + def test_group_base_properties(self): + """Test group base properties""" + properties = self.source.get_base_user_properties(root=ROOT_GROUPS, name_id=NAME_ID) + self.assertEqual(properties["groups"], ["group 1", "group 2"]) + for group_id in ["group 1", "group 2"]: + properties = self.source.get_base_group_properties(root=ROOT, group_id=group_id) + self.assertEqual(properties, {"name": group_id}) + + def test_user_property_mappings(self): + """Test user property mappings""" + self.source.user_property_mappings.add( + SAMLSourcePropertyMapping.objects.create( + name="test", + expression="return {'attributes': {'department': 'Engineering'}, 'sn': None}", + ) + ) + request = self.factory.post( + "/", + data={ + "SAMLResponse": b64encode( + load_fixture("fixtures/response_success.xml").encode() + ).decode() + }, + ) + + middleware = SessionMiddleware(dummy_get_response) + middleware.process_request(request) + request.session.save() + + parser = ResponseProcessor(self.source, request) + parser.parse() + sfm = parser.prepare_flow_manager() + self.assertEqual( + sfm.user_properties, + { + "email": "foo@bar.baz", + "name": "foo", + "username": "jens@goauthentik.io", + "attributes": { + "department": "Engineering", + }, + "path": self.source.get_user_path(), + }, + ) + + def test_group_property_mappings(self): + """Test group property mappings""" + self.source.group_property_mappings.add( + SAMLSourcePropertyMapping.objects.create( + name="test", + expression="return {'attributes': {'id': group_id}}", + ) + ) + request = self.factory.post( + "/", + data={ + "SAMLResponse": b64encode( + load_fixture("fixtures/response_success_groups.xml").encode() + ).decode() + }, + ) + + middleware = SessionMiddleware(dummy_get_response) + middleware.process_request(request) + request.session.save() + + parser = ResponseProcessor(self.source, request) + parser.parse() + sfm = parser.prepare_flow_manager() + self.assertEqual( + sfm.groups_properties, + { + "group 1": { + "name": "group 1", + "attributes": { + "id": "group 1", + }, + }, + "group 2": { + "name": "group 2", + "attributes": { + "id": "group 2", + }, + }, + }, + ) diff --git a/authentik/sources/saml/tests/test_response.py b/authentik/sources/saml/tests/test_response.py index b22957c8d3..a56e3d4c19 100644 --- a/authentik/sources/saml/tests/test_response.py +++ b/authentik/sources/saml/tests/test_response.py @@ -67,6 +67,13 @@ class TestResponseProcessor(TestCase): parser.parse() sfm = parser.prepare_flow_manager() self.assertEqual( - sfm.enroll_info, - {"email": "foo@bar.baz", "name": "foo", "sn": "bar", "username": "jens@goauthentik.io"}, + sfm.user_properties, + { + "email": "foo@bar.baz", + "name": "foo", + "sn": "bar", + "username": "jens@goauthentik.io", + "attributes": {}, + "path": self.source.get_user_path(), + }, ) diff --git a/authentik/sources/saml/urls.py b/authentik/sources/saml/urls.py index 6abeb3f4db..745062fb70 100644 --- a/authentik/sources/saml/urls.py +++ b/authentik/sources/saml/urls.py @@ -2,8 +2,12 @@ from django.urls import path +from authentik.sources.saml.api.property_mappings import SAMLSourcePropertyMappingViewSet from authentik.sources.saml.api.source import SAMLSourceViewSet -from authentik.sources.saml.api.source_connection import UserSAMLSourceConnectionViewSet +from authentik.sources.saml.api.source_connection import ( + GroupSAMLSourceConnectionViewSet, + UserSAMLSourceConnectionViewSet, +) from authentik.sources.saml.views import ACSView, InitiateView, MetadataView, SLOView urlpatterns = [ @@ -14,6 +18,8 @@ urlpatterns = [ ] api_urlpatterns = [ + ("propertymappings/source/saml", SAMLSourcePropertyMappingViewSet), ("sources/user_connections/saml", UserSAMLSourceConnectionViewSet), + ("sources/group_connections/saml", GroupSAMLSourceConnectionViewSet), ("sources/saml", SAMLSourceViewSet), ] diff --git a/blueprints/schema.json b/blueprints/schema.json index cf2d42abc4..0a7811f3b6 100644 --- a/blueprints/schema.json +++ b/blueprints/schema.json @@ -1201,6 +1201,46 @@ } } }, + { + "type": "object", + "required": [ + "model", + "identifiers" + ], + "properties": { + "model": { + "const": "authentik_sources_oauth.oauthsourcepropertymapping" + }, + "id": { + "type": "string" + }, + "state": { + "type": "string", + "enum": [ + "absent", + "present", + "created", + "must_created" + ], + "default": "present" + }, + "conditions": { + "type": "array", + "items": { + "type": "boolean" + } + }, + "permissions": { + "$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping_permissions" + }, + "attrs": { + "$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping" + }, + "identifiers": { + "$ref": "#/$defs/model_authentik_sources_oauth.oauthsourcepropertymapping" + } + } + }, { "type": "object", "required": [ @@ -1241,6 +1281,46 @@ } } }, + { + "type": "object", + "required": [ + "model", + "identifiers" + ], + "properties": { + "model": { + "const": "authentik_sources_oauth.groupoauthsourceconnection" + }, + "id": { + "type": "string" + }, + "state": { + "type": "string", + "enum": [ + "absent", + "present", + "created", + "must_created" + ], + "default": "present" + }, + "conditions": { + "type": "array", + "items": { + "type": "boolean" + } + }, + "permissions": { + "$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection_permissions" + }, + "attrs": { + "$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection" + }, + "identifiers": { + "$ref": "#/$defs/model_authentik_sources_oauth.groupoauthsourceconnection" + } + } + }, { "type": "object", "required": [ @@ -1361,6 +1441,46 @@ } } }, + { + "type": "object", + "required": [ + "model", + "identifiers" + ], + "properties": { + "model": { + "const": "authentik_sources_saml.samlsourcepropertymapping" + }, + "id": { + "type": "string" + }, + "state": { + "type": "string", + "enum": [ + "absent", + "present", + "created", + "must_created" + ], + "default": "present" + }, + "conditions": { + "type": "array", + "items": { + "type": "boolean" + } + }, + "permissions": { + "$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping_permissions" + }, + "attrs": { + "$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping" + }, + "identifiers": { + "$ref": "#/$defs/model_authentik_sources_saml.samlsourcepropertymapping" + } + } + }, { "type": "object", "required": [ @@ -1401,6 +1521,46 @@ } } }, + { + "type": "object", + "required": [ + "model", + "identifiers" + ], + "properties": { + "model": { + "const": "authentik_sources_saml.groupsamlsourceconnection" + }, + "id": { + "type": "string" + }, + "state": { + "type": "string", + "enum": [ + "absent", + "present", + "created", + "must_created" + ], + "default": "present" + }, + "conditions": { + "type": "array", + "items": { + "type": "boolean" + } + }, + "permissions": { + "$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection_permissions" + }, + "attrs": { + "$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection" + }, + "identifiers": { + "$ref": "#/$defs/model_authentik_sources_saml.groupsamlsourceconnection" + } + } + }, { "type": "object", "required": [ @@ -4106,11 +4266,15 @@ "authentik_sources_ldap.ldapsource", "authentik_sources_ldap.ldapsourcepropertymapping", "authentik_sources_oauth.oauthsource", + "authentik_sources_oauth.oauthsourcepropertymapping", "authentik_sources_oauth.useroauthsourceconnection", + "authentik_sources_oauth.groupoauthsourceconnection", "authentik_sources_plex.plexsource", "authentik_sources_plex.plexsourceconnection", "authentik_sources_saml.samlsource", + "authentik_sources_saml.samlsourcepropertymapping", "authentik_sources_saml.usersamlsourceconnection", + "authentik_sources_saml.groupsamlsourceconnection", "authentik_sources_scim.scimsource", "authentik_sources_scim.scimsourcepropertymapping", "authentik_stages_authenticator_duo.authenticatorduostage", @@ -6615,6 +6779,16 @@ "minLength": 1, "title": "Icon" }, + "group_matching_mode": { + "type": "string", + "enum": [ + "identifier", + "name_link", + "name_deny" + ], + "title": "Group matching mode", + "description": "How the source determines if an existing group should be used or a new group created." + }, "provider_type": { "type": "string", "enum": [ @@ -6727,6 +6901,57 @@ } } }, + "model_authentik_sources_oauth.oauthsourcepropertymapping": { + "type": "object", + "properties": { + "managed": { + "type": [ + "string", + "null" + ], + "minLength": 1, + "title": "Managed by authentik", + "description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update." + }, + "name": { + "type": "string", + "minLength": 1, + "title": "Name" + }, + "expression": { + "type": "string", + "minLength": 1, + "title": "Expression" + } + }, + "required": [] + }, + "model_authentik_sources_oauth.oauthsourcepropertymapping_permissions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "permission" + ], + "properties": { + "permission": { + "type": "string", + "enum": [ + "add_oauthsourcepropertymapping", + "change_oauthsourcepropertymapping", + "delete_oauthsourcepropertymapping", + "view_oauthsourcepropertymapping" + ] + }, + "user": { + "type": "integer" + }, + "role": { + "type": "string" + } + } + } + }, "model_authentik_sources_oauth.useroauthsourceconnection": { "type": "object", "properties": { @@ -6777,6 +7002,43 @@ } } }, + "model_authentik_sources_oauth.groupoauthsourceconnection": { + "type": "object", + "properties": { + "icon": { + "type": "string", + "minLength": 1, + "title": "Icon" + } + }, + "required": [] + }, + "model_authentik_sources_oauth.groupoauthsourceconnection_permissions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "permission" + ], + "properties": { + "permission": { + "type": "string", + "enum": [ + "add_groupoauthsourceconnection", + "change_groupoauthsourceconnection", + "delete_groupoauthsourceconnection", + "view_groupoauthsourceconnection" + ] + }, + "user": { + "type": "integer" + }, + "role": { + "type": "string" + } + } + } + }, "model_authentik_sources_plex.plexsource": { "type": "object", "properties": { @@ -7038,6 +7300,16 @@ "minLength": 1, "title": "Icon" }, + "group_matching_mode": { + "type": "string", + "enum": [ + "identifier", + "name_link", + "name_deny" + ], + "title": "Group matching mode", + "description": "How the source determines if an existing group should be used or a new group created." + }, "pre_authentication_flow": { "type": "string", "format": "uuid", @@ -7165,6 +7437,57 @@ } } }, + "model_authentik_sources_saml.samlsourcepropertymapping": { + "type": "object", + "properties": { + "managed": { + "type": [ + "string", + "null" + ], + "minLength": 1, + "title": "Managed by authentik", + "description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update." + }, + "name": { + "type": "string", + "minLength": 1, + "title": "Name" + }, + "expression": { + "type": "string", + "minLength": 1, + "title": "Expression" + } + }, + "required": [] + }, + "model_authentik_sources_saml.samlsourcepropertymapping_permissions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "permission" + ], + "properties": { + "permission": { + "type": "string", + "enum": [ + "add_samlsourcepropertymapping", + "change_samlsourcepropertymapping", + "delete_samlsourcepropertymapping", + "view_samlsourcepropertymapping" + ] + }, + "user": { + "type": "integer" + }, + "role": { + "type": "string" + } + } + } + }, "model_authentik_sources_saml.usersamlsourceconnection": { "type": "object", "properties": { @@ -7207,6 +7530,43 @@ } } }, + "model_authentik_sources_saml.groupsamlsourceconnection": { + "type": "object", + "properties": { + "icon": { + "type": "string", + "minLength": 1, + "title": "Icon" + } + }, + "required": [] + }, + "model_authentik_sources_saml.groupsamlsourceconnection_permissions": { + "type": "array", + "items": { + "type": "object", + "required": [ + "permission" + ], + "properties": { + "permission": { + "type": "string", + "enum": [ + "add_groupsamlsourceconnection", + "change_groupsamlsourceconnection", + "delete_groupsamlsourceconnection", + "view_groupsamlsourceconnection" + ] + }, + "user": { + "type": "integer" + }, + "role": { + "type": "string" + } + } + } + }, "model_authentik_sources_scim.scimsource": { "type": "object", "properties": { @@ -10969,7 +11329,6 @@ "properties": { "name": { "type": "string", - "maxLength": 80, "minLength": 1, "title": "Name" }, diff --git a/schema.yml b/schema.yml index 5c2ffc5701..ba413f6786 100644 --- a/schema.yml +++ b/schema.yml @@ -16259,6 +16259,568 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' + /propertymappings/source/oauth/: + get: + operationId: propertymappings_source_oauth_list + description: OAuthSourcePropertyMapping Viewset + parameters: + - in: query + name: managed + schema: + type: array + items: + type: string + explode: true + style: form + - in: query + name: managed__isnull + schema: + type: boolean + - in: query + name: name + schema: + type: string + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: search + required: false + in: query + description: A search term. + schema: + type: string + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedOAuthSourcePropertyMappingList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + post: + operationId: propertymappings_source_oauth_create + description: OAuthSourcePropertyMapping Viewset + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMappingRequest' + required: true + security: + - authentik: [] + responses: + '201': + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /propertymappings/source/oauth/{pm_uuid}/: + get: + operationId: propertymappings_source_oauth_retrieve + description: OAuthSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this OAuth Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + put: + operationId: propertymappings_source_oauth_update + description: OAuthSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this OAuth Source Property Mapping. + required: true + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMappingRequest' + required: true + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + patch: + operationId: propertymappings_source_oauth_partial_update + description: OAuthSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this OAuth Source Property Mapping. + required: true + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedOAuthSourcePropertyMappingRequest' + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/OAuthSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + delete: + operationId: propertymappings_source_oauth_destroy + description: OAuthSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this OAuth Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '204': + description: No response body + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /propertymappings/source/oauth/{pm_uuid}/used_by/: + get: + operationId: propertymappings_source_oauth_used_by_list + description: Get a list of all objects that use this object + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this OAuth Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UsedBy' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /propertymappings/source/saml/: + get: + operationId: propertymappings_source_saml_list + description: SAMLSourcePropertyMapping Viewset + parameters: + - in: query + name: managed + schema: + type: array + items: + type: string + explode: true + style: form + - in: query + name: managed__isnull + schema: + type: boolean + - in: query + name: name + schema: + type: string + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: search + required: false + in: query + description: A search term. + schema: + type: string + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedSAMLSourcePropertyMappingList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + post: + operationId: propertymappings_source_saml_create + description: SAMLSourcePropertyMapping Viewset + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMappingRequest' + required: true + security: + - authentik: [] + responses: + '201': + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /propertymappings/source/saml/{pm_uuid}/: + get: + operationId: propertymappings_source_saml_retrieve + description: SAMLSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this SAML Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + put: + operationId: propertymappings_source_saml_update + description: SAMLSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this SAML Source Property Mapping. + required: true + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMappingRequest' + required: true + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + patch: + operationId: propertymappings_source_saml_partial_update + description: SAMLSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this SAML Source Property Mapping. + required: true + tags: + - propertymappings + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PatchedSAMLSourcePropertyMappingRequest' + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/SAMLSourcePropertyMapping' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + delete: + operationId: propertymappings_source_saml_destroy + description: SAMLSourcePropertyMapping Viewset + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this SAML Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '204': + description: No response body + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /propertymappings/source/saml/{pm_uuid}/used_by/: + get: + operationId: propertymappings_source_saml_used_by_list + description: Get a list of all objects that use this object + parameters: + - in: path + name: pm_uuid + schema: + type: string + format: uuid + description: A UUID string identifying this SAML Source Property Mapping. + required: true + tags: + - propertymappings + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UsedBy' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' /propertymappings/source/scim/: get: operationId: propertymappings_source_scim_list @@ -21609,11 +22171,15 @@ paths: - authentik_rbac.role - authentik_sources_ldap.ldapsource - authentik_sources_ldap.ldapsourcepropertymapping + - authentik_sources_oauth.groupoauthsourceconnection - authentik_sources_oauth.oauthsource + - authentik_sources_oauth.oauthsourcepropertymapping - authentik_sources_oauth.useroauthsourceconnection - authentik_sources_plex.plexsource - authentik_sources_plex.plexsourceconnection + - authentik_sources_saml.groupsamlsourceconnection - authentik_sources_saml.samlsource + - authentik_sources_saml.samlsourcepropertymapping - authentik_sources_saml.usersamlsourceconnection - authentik_sources_scim.scimsource - authentik_sources_scim.scimsourcepropertymapping @@ -21837,11 +22403,15 @@ paths: - authentik_rbac.role - authentik_sources_ldap.ldapsource - authentik_sources_ldap.ldapsourcepropertymapping + - authentik_sources_oauth.groupoauthsourceconnection - authentik_sources_oauth.oauthsource + - authentik_sources_oauth.oauthsourcepropertymapping - authentik_sources_oauth.useroauthsourceconnection - authentik_sources_plex.plexsource - authentik_sources_plex.plexsourceconnection + - authentik_sources_saml.groupsamlsourceconnection - authentik_sources_saml.samlsource + - authentik_sources_saml.samlsourcepropertymapping - authentik_sources_saml.usersamlsourceconnection - authentik_sources_scim.scimsource - authentik_sources_scim.scimsourcepropertymapping @@ -23136,6 +23706,484 @@ paths: schema: $ref: '#/components/schemas/GenericError' description: '' + /sources/group_connections/oauth/: + get: + operationId: sources_group_connections_oauth_list + description: Group-source connection Viewset + parameters: + - in: query + name: group + schema: + type: string + format: uuid + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: search + required: false + in: query + description: A search term. + schema: + type: string + - in: query + name: source__slug + schema: + type: string + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedGroupOAuthSourceConnectionList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + post: + operationId: sources_group_connections_oauth_create + description: Group-source connection Viewset + tags: + - sources + security: + - authentik: [] + responses: + '201': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupOAuthSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /sources/group_connections/oauth/{id}/: + get: + operationId: sources_group_connections_oauth_retrieve + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group OAuth Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupOAuthSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + put: + operationId: sources_group_connections_oauth_update + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group OAuth Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupOAuthSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + patch: + operationId: sources_group_connections_oauth_partial_update + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group OAuth Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupOAuthSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + delete: + operationId: sources_group_connections_oauth_destroy + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group OAuth Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '204': + description: No response body + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /sources/group_connections/oauth/{id}/used_by/: + get: + operationId: sources_group_connections_oauth_used_by_list + description: Get a list of all objects that use this object + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group OAuth Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UsedBy' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /sources/group_connections/saml/: + get: + operationId: sources_group_connections_saml_list + description: Group-source connection Viewset + parameters: + - in: query + name: group + schema: + type: string + format: uuid + - name: ordering + required: false + in: query + description: Which field to use when ordering the results. + schema: + type: string + - name: page + required: false + in: query + description: A page number within the paginated result set. + schema: + type: integer + - name: page_size + required: false + in: query + description: Number of results to return per page. + schema: + type: integer + - name: search + required: false + in: query + description: A search term. + schema: + type: string + - in: query + name: source__slug + schema: + type: string + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/PaginatedGroupSAMLSourceConnectionList' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /sources/group_connections/saml/{id}/: + get: + operationId: sources_group_connections_saml_retrieve + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group SAML Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupSAMLSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + put: + operationId: sources_group_connections_saml_update + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group SAML Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupSAMLSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + patch: + operationId: sources_group_connections_saml_partial_update + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group SAML Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/GroupSAMLSourceConnection' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + delete: + operationId: sources_group_connections_saml_destroy + description: Group-source connection Viewset + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group SAML Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '204': + description: No response body + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' + /sources/group_connections/saml/{id}/used_by/: + get: + operationId: sources_group_connections_saml_used_by_list + description: Get a list of all objects that use this object + parameters: + - in: path + name: id + schema: + type: integer + description: A unique integer value identifying this Group SAML Source Connection. + required: true + tags: + - sources + security: + - authentik: [] + responses: + '200': + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/UsedBy' + description: '' + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/ValidationError' + description: '' + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/GenericError' + description: '' /sources/ldap/: get: operationId: sources_ldap_list @@ -23604,6 +24652,17 @@ paths: schema: type: string format: uuid + - in: query + name: group_matching_mode + schema: + type: string + enum: + - identifier + - name_deny + - name_link + description: |+ + How the source determines if an existing group should be used or a new group created. + - in: query name: has_jwks schema: @@ -38423,7 +39482,6 @@ components: readOnly: true name: type: string - maxLength: 80 is_superuser: type: boolean description: Users added to this group will be superusers. @@ -38465,6 +39523,12 @@ components: - pk - roles_obj - users_obj + GroupMatchingModeEnum: + enum: + - identifier + - name_link + - name_deny + type: string GroupMember: type: object description: Stripped down user serializer to show relevant users for groups @@ -38542,6 +39606,35 @@ components: required: - name - username + GroupOAuthSourceConnection: + type: object + description: OAuth Group-Source connection Serializer + properties: + pk: + type: integer + readOnly: true + title: ID + group: + type: string + format: uuid + readOnly: true + source: + allOf: + - $ref: '#/components/schemas/Source' + readOnly: true + identifier: + type: string + readOnly: true + created: + type: string + format: date-time + readOnly: true + required: + - created + - group + - identifier + - pk + - source GroupRequest: type: object description: Group Serializer @@ -38549,7 +39642,6 @@ components: name: type: string minLength: 1 - maxLength: 80 is_superuser: type: boolean description: Users added to this group will be superusers. @@ -38571,6 +39663,35 @@ components: format: uuid required: - name + GroupSAMLSourceConnection: + type: object + description: OAuth Group-Source connection Serializer + properties: + pk: + type: integer + readOnly: true + title: ID + group: + type: string + format: uuid + readOnly: true + source: + allOf: + - $ref: '#/components/schemas/Source' + readOnly: true + identifier: + type: string + readOnly: true + created: + type: string + format: date-time + readOnly: true + required: + - created + - group + - identifier + - pk + - source IdentificationChallenge: type: object description: Identification challenges with all UI elements @@ -40110,11 +41231,15 @@ components: - authentik_sources_ldap.ldapsource - authentik_sources_ldap.ldapsourcepropertymapping - authentik_sources_oauth.oauthsource + - authentik_sources_oauth.oauthsourcepropertymapping - authentik_sources_oauth.useroauthsourceconnection + - authentik_sources_oauth.groupoauthsourceconnection - authentik_sources_plex.plexsource - authentik_sources_plex.plexsourceconnection - authentik_sources_saml.samlsource + - authentik_sources_saml.samlsourcepropertymapping - authentik_sources_saml.usersamlsourceconnection + - authentik_sources_saml.groupsamlsourceconnection - authentik_sources_scim.scimsource - authentik_sources_scim.scimsourcepropertymapping - authentik_stages_authenticator_duo.authenticatorduostage @@ -40754,6 +41879,11 @@ components: type: string nullable: true readOnly: true + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. provider_type: $ref: '#/components/schemas/ProviderTypeEnum' request_token_url: @@ -40808,6 +41938,73 @@ components: - type - verbose_name - verbose_name_plural + OAuthSourcePropertyMapping: + type: object + description: OAuthSourcePropertyMapping Serializer + properties: + pk: + type: string + format: uuid + readOnly: true + title: Pm uuid + managed: + type: string + nullable: true + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + expression: + type: string + component: + type: string + description: Get object's component so that we know how to edit the object + readOnly: true + verbose_name: + type: string + description: Return object's verbose_name + readOnly: true + verbose_name_plural: + type: string + description: Return object's plural verbose_name + readOnly: true + meta_model_name: + type: string + description: Return internal model name + readOnly: true + required: + - component + - expression + - meta_model_name + - name + - pk + - verbose_name + - verbose_name_plural + OAuthSourcePropertyMappingRequest: + type: object + description: OAuthSourcePropertyMapping Serializer + properties: + managed: + type: string + nullable: true + minLength: 1 + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + minLength: 1 + expression: + type: string + minLength: 1 + required: + - expression + - name OAuthSourceRequest: type: object description: OAuth Source Serializer @@ -40854,6 +42051,11 @@ components: user_path_template: type: string minLength: 1 + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. provider_type: $ref: '#/components/schemas/ProviderTypeEnum' request_token_url: @@ -41550,6 +42752,30 @@ components: required: - pagination - results + PaginatedGroupOAuthSourceConnectionList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/GroupOAuthSourceConnection' + required: + - pagination + - results + PaginatedGroupSAMLSourceConnectionList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/GroupSAMLSourceConnection' + required: + - pagination + - results PaginatedIdentificationStageList: type: object properties: @@ -41778,6 +43004,18 @@ components: required: - pagination - results + PaginatedOAuthSourcePropertyMappingList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/OAuthSourcePropertyMapping' + required: + - pagination + - results PaginatedOutpostList: type: object properties: @@ -42102,6 +43340,18 @@ components: required: - pagination - results + PaginatedSAMLSourcePropertyMappingList: + type: object + properties: + pagination: + $ref: '#/components/schemas/Pagination' + results: + type: array + items: + $ref: '#/components/schemas/SAMLSourcePropertyMapping' + required: + - pagination + - results PaginatedSCIMMappingList: type: object properties: @@ -43682,7 +44932,6 @@ components: name: type: string minLength: 1 - maxLength: 80 is_superuser: type: boolean description: Users added to this group will be superusers. @@ -44225,6 +45474,25 @@ components: title: Any JWT signed by the JWK of the selected source can be used to authenticate. title: Any JWT signed by the JWK of the selected source can be used to authenticate. + PatchedOAuthSourcePropertyMappingRequest: + type: object + description: OAuthSourcePropertyMapping Serializer + properties: + managed: + type: string + nullable: true + minLength: 1 + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + minLength: 1 + expression: + type: string + minLength: 1 PatchedOAuthSourceRequest: type: object description: OAuth Source Serializer @@ -44271,6 +45539,11 @@ components: user_path_template: type: string minLength: 1 + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. provider_type: $ref: '#/components/schemas/ProviderTypeEnum' request_token_url: @@ -44963,6 +46236,25 @@ components: default_relay_state: type: string description: Default relay_state value for IDP-initiated logins + PatchedSAMLSourcePropertyMappingRequest: + type: object + description: SAMLSourcePropertyMapping Serializer + properties: + managed: + type: string + nullable: true + minLength: 1 + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + minLength: 1 + expression: + type: string + minLength: 1 PatchedSAMLSourceRequest: type: object description: SAMLSource Serializer @@ -45009,6 +46301,11 @@ components: user_path_template: type: string minLength: 1 + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. pre_authentication_flow: type: string format: uuid @@ -47823,6 +49120,11 @@ components: icon: type: string readOnly: true + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. pre_authentication_flow: type: string format: uuid @@ -47888,6 +49190,73 @@ components: - sso_url - verbose_name - verbose_name_plural + SAMLSourcePropertyMapping: + type: object + description: SAMLSourcePropertyMapping Serializer + properties: + pk: + type: string + format: uuid + readOnly: true + title: Pm uuid + managed: + type: string + nullable: true + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + expression: + type: string + component: + type: string + description: Get object's component so that we know how to edit the object + readOnly: true + verbose_name: + type: string + description: Return object's verbose_name + readOnly: true + verbose_name_plural: + type: string + description: Return object's plural verbose_name + readOnly: true + meta_model_name: + type: string + description: Return internal model name + readOnly: true + required: + - component + - expression + - meta_model_name + - name + - pk + - verbose_name + - verbose_name_plural + SAMLSourcePropertyMappingRequest: + type: object + description: SAMLSourcePropertyMapping Serializer + properties: + managed: + type: string + nullable: true + minLength: 1 + title: Managed by authentik + description: Objects that are managed by authentik. These objects are created + and updated automatically. This flag only indicates that an object can + be overwritten by migrations. You can still modify the objects via the + API, but expect changes to be overwritten in a later update. + name: + type: string + minLength: 1 + expression: + type: string + minLength: 1 + required: + - expression + - name SAMLSourceRequest: type: object description: SAMLSource Serializer @@ -47934,6 +49303,11 @@ components: user_path_template: type: string minLength: 1 + group_matching_mode: + allOf: + - $ref: '#/components/schemas/GroupMatchingModeEnum' + description: How the source determines if an existing group should be used + or a new group created. pre_authentication_flow: type: string format: uuid @@ -49903,7 +51277,6 @@ components: readOnly: true name: type: string - maxLength: 80 is_superuser: type: boolean description: Users added to this group will be superusers. @@ -49930,7 +51303,6 @@ components: name: type: string minLength: 1 - maxLength: 80 is_superuser: type: boolean description: Users added to this group will be superusers. diff --git a/tests/e2e/test_source_oauth_oauth1.py b/tests/e2e/test_source_oauth_oauth1.py index cbeb66edfd..9ebc02d811 100644 --- a/tests/e2e/test_source_oauth_oauth1.py +++ b/tests/e2e/test_source_oauth_oauth1.py @@ -25,16 +25,6 @@ class OAuth1Callback(OAuthCallback): def get_user_id(self, info: dict[str, str]) -> str: return info.get("id") - def get_user_enroll_context( - self, - info: dict[str, Any], - ) -> dict[str, Any]: - return { - "username": info.get("screen_name"), - "email": info.get("email"), - "name": info.get("name"), - } - @registry.register() class OAUth1Type(SourceType): @@ -50,6 +40,13 @@ class OAUth1Type(SourceType): profile_url = "http://localhost:5001/api/me" urls_customizable = False + def get_base_user_properties(self, info: dict[str, Any], **kwargs) -> dict[str, Any]: + return { + "username": info.get("screen_name"), + "email": info.get("email"), + "name": info.get("name"), + } + class TestSourceOAuth1(SeleniumTestCase): """Test OAuth1 Source""" diff --git a/web/src/admin/property-mappings/PropertyMappingLDAPSourceForm.ts b/web/src/admin/property-mappings/PropertyMappingLDAPSourceForm.ts index d64709428c..a8642aa196 100644 --- a/web/src/admin/property-mappings/PropertyMappingLDAPSourceForm.ts +++ b/web/src/admin/property-mappings/PropertyMappingLDAPSourceForm.ts @@ -57,7 +57,9 @@ export class PropertyMappingLDAPSourceForm extends BasePropertyMappingForm ${msg("See documentation for a list of all variables.")} diff --git a/web/src/admin/property-mappings/PropertyMappingListPage.ts b/web/src/admin/property-mappings/PropertyMappingListPage.ts index 8db861924b..5d4b32450c 100644 --- a/web/src/admin/property-mappings/PropertyMappingListPage.ts +++ b/web/src/admin/property-mappings/PropertyMappingListPage.ts @@ -2,9 +2,11 @@ import "@goauthentik/admin/property-mappings/PropertyMappingGoogleWorkspaceForm" import "@goauthentik/admin/property-mappings/PropertyMappingLDAPSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingMicrosoftEntraForm"; import "@goauthentik/admin/property-mappings/PropertyMappingNotification"; +import "@goauthentik/admin/property-mappings/PropertyMappingOAuthSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingRACForm"; import "@goauthentik/admin/property-mappings/PropertyMappingRadiusForm"; import "@goauthentik/admin/property-mappings/PropertyMappingSAMLForm"; +import "@goauthentik/admin/property-mappings/PropertyMappingSAMLSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingSCIMForm"; import "@goauthentik/admin/property-mappings/PropertyMappingSCIMSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingScopeForm"; diff --git a/web/src/admin/property-mappings/PropertyMappingOAuthSourceForm.ts b/web/src/admin/property-mappings/PropertyMappingOAuthSourceForm.ts new file mode 100644 index 0000000000..1899b7d713 --- /dev/null +++ b/web/src/admin/property-mappings/PropertyMappingOAuthSourceForm.ts @@ -0,0 +1,75 @@ +import { BasePropertyMappingForm } from "@goauthentik/admin/property-mappings/BasePropertyMappingForm"; +import { DEFAULT_CONFIG } from "@goauthentik/common/api/config"; +import { docLink } from "@goauthentik/common/global"; +import "@goauthentik/elements/CodeMirror"; +import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror"; +import "@goauthentik/elements/forms/HorizontalFormElement"; + +import { msg } from "@lit/localize"; +import { TemplateResult, html } from "lit"; +import { customElement } from "lit/decorators.js"; +import { ifDefined } from "lit/directives/if-defined.js"; + +import { OAuthSourcePropertyMapping, PropertymappingsApi } from "@goauthentik/api"; + +@customElement("ak-property-mapping-oauth-source-form") +export class PropertyMappingOAuthSourceForm extends BasePropertyMappingForm { + loadInstance(pk: string): Promise { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthRetrieve({ + pmUuid: pk, + }); + } + + async send(data: OAuthSourcePropertyMapping): Promise { + if (this.instance) { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthUpdate({ + pmUuid: this.instance.pk, + oAuthSourcePropertyMappingRequest: data, + }); + } else { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceOauthCreate({ + oAuthSourcePropertyMappingRequest: data, + }); + } + } + + renderForm(): TemplateResult { + return html` + + + + + +

+ ${msg("Expression using Python.")} + + ${msg("See documentation for a list of all variables.")} + +

+
`; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-property-mapping-oauth-source-form": PropertyMappingOAuthSourceForm; + } +} diff --git a/web/src/admin/property-mappings/PropertyMappingSAMLSourceForm.ts b/web/src/admin/property-mappings/PropertyMappingSAMLSourceForm.ts new file mode 100644 index 0000000000..9c01363fdb --- /dev/null +++ b/web/src/admin/property-mappings/PropertyMappingSAMLSourceForm.ts @@ -0,0 +1,75 @@ +import { BasePropertyMappingForm } from "@goauthentik/admin/property-mappings/BasePropertyMappingForm"; +import { DEFAULT_CONFIG } from "@goauthentik/common/api/config"; +import { docLink } from "@goauthentik/common/global"; +import "@goauthentik/elements/CodeMirror"; +import { CodeMirrorMode } from "@goauthentik/elements/CodeMirror"; +import "@goauthentik/elements/forms/HorizontalFormElement"; + +import { msg } from "@lit/localize"; +import { TemplateResult, html } from "lit"; +import { customElement } from "lit/decorators.js"; +import { ifDefined } from "lit/directives/if-defined.js"; + +import { PropertymappingsApi, SAMLSourcePropertyMapping } from "@goauthentik/api"; + +@customElement("ak-property-mapping-saml-source-form") +export class PropertyMappingSAMLSourceForm extends BasePropertyMappingForm { + loadInstance(pk: string): Promise { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlRetrieve({ + pmUuid: pk, + }); + } + + async send(data: SAMLSourcePropertyMapping): Promise { + if (this.instance) { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlUpdate({ + pmUuid: this.instance.pk, + sAMLSourcePropertyMappingRequest: data, + }); + } else { + return new PropertymappingsApi(DEFAULT_CONFIG).propertymappingsSourceSamlCreate({ + sAMLSourcePropertyMappingRequest: data, + }); + } + } + + renderForm(): TemplateResult { + return html` + + + + + +

+ ${msg("Expression using Python.")} + + ${msg("See documentation for a list of all variables.")} + +

+
`; + } +} + +declare global { + interface HTMLElementTagNameMap { + "ak-property-mapping-saml-source-form": PropertyMappingSAMLSourceForm; + } +} diff --git a/web/src/admin/property-mappings/PropertyMappingWizard.ts b/web/src/admin/property-mappings/PropertyMappingWizard.ts index 9e06c8d363..c2ec879367 100644 --- a/web/src/admin/property-mappings/PropertyMappingWizard.ts +++ b/web/src/admin/property-mappings/PropertyMappingWizard.ts @@ -1,7 +1,9 @@ import "@goauthentik/admin/property-mappings/PropertyMappingLDAPSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingNotification"; +import "@goauthentik/admin/property-mappings/PropertyMappingOAuthSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingRACForm"; import "@goauthentik/admin/property-mappings/PropertyMappingSAMLForm"; +import "@goauthentik/admin/property-mappings/PropertyMappingSAMLSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingSCIMSourceForm"; import "@goauthentik/admin/property-mappings/PropertyMappingScopeForm"; import "@goauthentik/admin/property-mappings/PropertyMappingTestForm"; diff --git a/web/src/admin/sources/oauth/OAuthSourceForm.ts b/web/src/admin/sources/oauth/OAuthSourceForm.ts index 36ca76afb6..4ec1e23c48 100644 --- a/web/src/admin/sources/oauth/OAuthSourceForm.ts +++ b/web/src/admin/sources/oauth/OAuthSourceForm.ts @@ -1,7 +1,10 @@ import "@goauthentik/admin/common/ak-flow-search/ak-source-flow-search"; import { iconHelperText, placeholderHelperText } from "@goauthentik/admin/helperText"; import { BaseSourceForm } from "@goauthentik/admin/sources/BaseSourceForm"; -import { UserMatchingModeToLabel } from "@goauthentik/admin/sources/oauth/utils"; +import { + GroupMatchingModeToLabel, + UserMatchingModeToLabel, +} from "@goauthentik/admin/sources/oauth/utils"; import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config"; import { first } from "@goauthentik/common/utils"; import "@goauthentik/elements/CodeMirror"; @@ -10,6 +13,8 @@ import { CapabilitiesEnum, WithCapabilitiesConfig, } from "@goauthentik/elements/Interface/capabilitiesProvider"; +import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js"; +import { DualSelectPair } from "@goauthentik/elements/ak-dual-select/types.js"; import "@goauthentik/elements/forms/FormGroup"; import "@goauthentik/elements/forms/HorizontalFormElement"; import "@goauthentik/elements/forms/SearchSelect"; @@ -21,14 +26,39 @@ import { ifDefined } from "lit/directives/if-defined.js"; import { FlowsInstancesListDesignationEnum, + GroupMatchingModeEnum, OAuthSource, + OAuthSourcePropertyMapping, OAuthSourceRequest, + PropertymappingsApi, ProviderTypeEnum, SourceType, SourcesApi, UserMatchingModeEnum, } from "@goauthentik/api"; +async function propertyMappingsProvider(page = 1, search = "") { + const propertyMappings = await new PropertymappingsApi( + DEFAULT_CONFIG, + ).propertymappingsSourceOauthList({ + ordering: "managed", + pageSize: 20, + search: search.trim(), + page, + }); + return { + pagination: propertyMappings.pagination, + options: propertyMappings.results.map((m) => [m.pk, m.name, m.name, m]), + }; +} + +function makePropertyMappingsSelector(instanceMappings?: string[]) { + const localMappings = instanceMappings ? new Set(instanceMappings) : undefined; + return localMappings + ? ([pk, _]: DualSelectPair) => localMappings.has(pk) + : ([_0, _1, _2, _]: DualSelectPair) => false; +} + @customElement("ak-source-oauth-form") export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm) { async loadInstance(pk: string): Promise { @@ -40,6 +70,8 @@ export class OAuthSourceForm extends WithCapabilitiesConfig(BaseSourceForm + + + ${this.renderUrlOptions()} + + ${msg("OAuth Attribute mapping")} +
+ + +

+ ${msg("Property mappings for user creation.")} +

+
+ + +

+ ${msg("Property mappings for group creation.")} +

+
+
+
${msg("Flow settings")}
diff --git a/web/src/admin/sources/oauth/utils.ts b/web/src/admin/sources/oauth/utils.ts index fab271f197..c6825455ed 100644 --- a/web/src/admin/sources/oauth/utils.ts +++ b/web/src/admin/sources/oauth/utils.ts @@ -1,6 +1,6 @@ import { msg } from "@lit/localize"; -import { UserMatchingModeEnum } from "@goauthentik/api"; +import { GroupMatchingModeEnum, UserMatchingModeEnum } from "@goauthentik/api"; export function UserMatchingModeToLabel(mode?: UserMatchingModeEnum): string { if (!mode) return ""; @@ -27,3 +27,19 @@ export function UserMatchingModeToLabel(mode?: UserMatchingModeEnum): string { return msg("Unknown user matching mode"); } } + +export function GroupMatchingModeToLabel(mode?: GroupMatchingModeEnum): string { + if (!mode) return ""; + switch (mode) { + case GroupMatchingModeEnum.Identifier: + return msg("Link users on unique identifier"); + case GroupMatchingModeEnum.NameLink: + return msg( + "Link to a group with identical name. Can have security implications when a group is used with another source", + ); + case GroupMatchingModeEnum.NameDeny: + return msg("Use the group's name, but deny enrollment when the name already exists"); + case UserMatchingModeEnum.UnknownDefaultOpenApi: + return msg("Unknown user matching mode"); + } +} diff --git a/web/src/admin/sources/saml/SAMLSourceForm.ts b/web/src/admin/sources/saml/SAMLSourceForm.ts index 0800c97678..18d11dde72 100644 --- a/web/src/admin/sources/saml/SAMLSourceForm.ts +++ b/web/src/admin/sources/saml/SAMLSourceForm.ts @@ -2,13 +2,18 @@ import "@goauthentik/admin/common/ak-crypto-certificate-search"; import "@goauthentik/admin/common/ak-flow-search/ak-source-flow-search"; import { iconHelperText, placeholderHelperText } from "@goauthentik/admin/helperText"; import { BaseSourceForm } from "@goauthentik/admin/sources/BaseSourceForm"; -import { UserMatchingModeToLabel } from "@goauthentik/admin/sources/oauth/utils"; +import { + GroupMatchingModeToLabel, + UserMatchingModeToLabel, +} from "@goauthentik/admin/sources/oauth/utils"; import { DEFAULT_CONFIG, config } from "@goauthentik/common/api/config"; import { first } from "@goauthentik/common/utils"; import { CapabilitiesEnum, WithCapabilitiesConfig, } from "@goauthentik/elements/Interface/capabilitiesProvider"; +import "@goauthentik/elements/ak-dual-select/ak-dual-select-dynamic-selected-provider.js"; +import { DualSelectPair } from "@goauthentik/elements/ak-dual-select/types.js"; import "@goauthentik/elements/forms/FormGroup"; import "@goauthentik/elements/forms/HorizontalFormElement"; import "@goauthentik/elements/forms/Radio"; @@ -23,13 +28,38 @@ import { BindingTypeEnum, DigestAlgorithmEnum, FlowsInstancesListDesignationEnum, + GroupMatchingModeEnum, NameIdPolicyEnum, + PropertymappingsApi, SAMLSource, + SAMLSourcePropertyMapping, SignatureAlgorithmEnum, SourcesApi, UserMatchingModeEnum, } from "@goauthentik/api"; +async function propertyMappingsProvider(page = 1, search = "") { + const propertyMappings = await new PropertymappingsApi( + DEFAULT_CONFIG, + ).propertymappingsSourceSamlList({ + ordering: "managed", + pageSize: 20, + search: search.trim(), + page, + }); + return { + pagination: propertyMappings.pagination, + options: propertyMappings.results.map((m) => [m.pk, m.name, m.name, m]), + }; +} + +function makePropertyMappingsSelector(instanceMappings?: string[]) { + const localMappings = instanceMappings ? new Set(instanceMappings) : undefined; + return localMappings + ? ([pk, _]: DualSelectPair) => localMappings.has(pk) + : ([_0, _1, _2, _]: DualSelectPair) => false; +} + @customElement("ak-source-saml-form") export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm) { @state() @@ -151,6 +181,35 @@ export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm + + + ${this.can(CapabilitiesEnum.CanSaveMedia) ? html` @@ -451,6 +510,43 @@ export class SAMLSourceForm extends WithCapabilitiesConfig(BaseSourceForm
+ + ${msg("SAML Attribute mapping")} +
+ + +

+ ${msg("Property mappings for user creation.")} +

+
+ + +

+ ${msg("Property mappings for group creation.")} +

+
+
+
${msg("Flow settings")}