core: extract object matching from flow manager (#11458)
This commit is contained in:
		 Marc 'risson' Schmitt
					Marc 'risson' Schmitt
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							3262e70eac
						
					
				
				
					commit
					b57df12ace
				
			| @ -1,11 +1,9 @@ | ||||
| """Source decision helper""" | ||||
|  | ||||
| from enum import Enum | ||||
| from typing import Any | ||||
|  | ||||
| from django.contrib import messages | ||||
| from django.db import IntegrityError, transaction | ||||
| from django.db.models.query_utils import Q | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.shortcuts import redirect | ||||
| from django.urls import reverse | ||||
| @ -16,12 +14,11 @@ from authentik.core.models import ( | ||||
|     Group, | ||||
|     GroupSourceConnection, | ||||
|     Source, | ||||
|     SourceGroupMatchingModes, | ||||
|     SourceUserMatchingModes, | ||||
|     User, | ||||
|     UserSourceConnection, | ||||
| ) | ||||
| from authentik.core.sources.mapper import SourceMapper | ||||
| from authentik.core.sources.matcher import Action, SourceMatcher | ||||
| from authentik.core.sources.stage import ( | ||||
|     PLAN_CONTEXT_SOURCES_CONNECTION, | ||||
|     PostSourceStage, | ||||
| @ -54,16 +51,6 @@ SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token" | ||||
| PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" | ||||
|  | ||||
|  | ||||
| class Action(Enum): | ||||
|     """Actions that can be decided based on the request | ||||
|     and source settings""" | ||||
|  | ||||
|     LINK = "link" | ||||
|     AUTH = "auth" | ||||
|     ENROLL = "enroll" | ||||
|     DENY = "deny" | ||||
|  | ||||
|  | ||||
| class MessageStage(StageView): | ||||
|     """Show a pre-configured message after the flow is done""" | ||||
|  | ||||
| @ -86,6 +73,7 @@ class SourceFlowManager: | ||||
|  | ||||
|     source: Source | ||||
|     mapper: SourceMapper | ||||
|     matcher: SourceMatcher | ||||
|     request: HttpRequest | ||||
|  | ||||
|     identifier: str | ||||
| @ -108,6 +96,9 @@ class SourceFlowManager: | ||||
|     ) -> None: | ||||
|         self.source = source | ||||
|         self.mapper = SourceMapper(self.source) | ||||
|         self.matcher = SourceMatcher( | ||||
|             self.source, self.user_connection_type, self.group_connection_type | ||||
|         ) | ||||
|         self.request = request | ||||
|         self.identifier = identifier | ||||
|         self.user_info = user_info | ||||
| @ -131,66 +122,19 @@ class SourceFlowManager: | ||||
|  | ||||
|     def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]:  # noqa: PLR0911 | ||||
|         """decide which action should be taken""" | ||||
|         new_connection = self.user_connection_type(source=self.source, identifier=self.identifier) | ||||
|         # When request is authenticated, always link | ||||
|         if self.request.user.is_authenticated: | ||||
|             new_connection = self.user_connection_type( | ||||
|                 source=self.source, identifier=self.identifier | ||||
|             ) | ||||
|             new_connection.user = self.request.user | ||||
|             new_connection = self.update_user_connection(new_connection, **kwargs) | ||||
|             return Action.LINK, new_connection | ||||
|  | ||||
|         existing_connections = self.user_connection_type.objects.filter( | ||||
|             source=self.source, identifier=self.identifier | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             connection = existing_connections.first() | ||||
|             return Action.AUTH, self.update_user_connection(connection, **kwargs) | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER: | ||||
|             # We don't save the connection here cause it doesn't have a user assigned yet | ||||
|             return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) | ||||
|  | ||||
|         # Check for existing users with matching attributes | ||||
|         query = Q() | ||||
|         # Either query existing user based on email or username | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_LINK, | ||||
|             SourceUserMatchingModes.EMAIL_DENY, | ||||
|         ]: | ||||
|             if not self.user_properties.get("email", None): | ||||
|                 self._logger.warning("Refusing to use none email") | ||||
|                 return Action.DENY, None | ||||
|             query = Q(email__exact=self.user_properties.get("email", None)) | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.USERNAME_LINK, | ||||
|             SourceUserMatchingModes.USERNAME_DENY, | ||||
|         ]: | ||||
|             if not self.user_properties.get("username", None): | ||||
|                 self._logger.warning("Refusing to use none username") | ||||
|                 return Action.DENY, None | ||||
|             query = Q(username__exact=self.user_properties.get("username", None)) | ||||
|         self._logger.debug("trying to link with existing user", query=query) | ||||
|         matching_users = User.objects.filter(query) | ||||
|         # No matching users, always enroll | ||||
|         if not matching_users.exists(): | ||||
|             self._logger.debug("no matching users found, enrolling") | ||||
|             return Action.ENROLL, self.update_user_connection(new_connection, **kwargs) | ||||
|  | ||||
|         user = matching_users.first() | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_LINK, | ||||
|             SourceUserMatchingModes.USERNAME_LINK, | ||||
|         ]: | ||||
|             new_connection.user = user | ||||
|             new_connection = self.update_user_connection(new_connection, **kwargs) | ||||
|             return Action.LINK, new_connection | ||||
|         if self.source.user_matching_mode in [ | ||||
|             SourceUserMatchingModes.EMAIL_DENY, | ||||
|             SourceUserMatchingModes.USERNAME_DENY, | ||||
|         ]: | ||||
|             self._logger.info("denying source because user exists", user=user) | ||||
|             return Action.DENY, None | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|         action, connection = self.matcher.get_user_action(self.identifier, self.user_properties) | ||||
|         if connection: | ||||
|             connection = self.update_user_connection(connection, **kwargs) | ||||
|         return action, connection | ||||
|  | ||||
|     def update_user_connection( | ||||
|         self, connection: UserSourceConnection, **kwargs | ||||
| @ -408,74 +352,16 @@ class SourceFlowManager: | ||||
| class GroupUpdateStage(StageView): | ||||
|     """Dynamically injected stage which updates the user after enrollment/authentication.""" | ||||
|  | ||||
|     def get_action( | ||||
|         self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, GroupSourceConnection | None]: | ||||
|         """decide which action should be taken""" | ||||
|         new_connection = self.group_connection_type(source=self.source, identifier=group_id) | ||||
|  | ||||
|         existing_connections = self.group_connection_type.objects.filter( | ||||
|             source=self.source, identifier=group_id | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             return Action.LINK, existing_connections.first() | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER: | ||||
|             # We don't save the connection here cause it doesn't have a user assigned yet | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         # Check for existing groups with matching attributes | ||||
|         query = Q() | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_LINK, | ||||
|             SourceGroupMatchingModes.NAME_DENY, | ||||
|         ]: | ||||
|             if not group_properties.get("name", None): | ||||
|                 LOGGER.warning( | ||||
|                     "Refusing to use none group name", source=self.source, group_id=group_id | ||||
|                 ) | ||||
|                 return Action.DENY, None | ||||
|             query = Q(name__exact=group_properties.get("name")) | ||||
|         LOGGER.debug( | ||||
|             "trying to link with existing group", source=self.source, query=query, group_id=group_id | ||||
|         ) | ||||
|         matching_groups = Group.objects.filter(query) | ||||
|         # No matching groups, always enroll | ||||
|         if not matching_groups.exists(): | ||||
|             LOGGER.debug( | ||||
|                 "no matching groups found, enrolling", source=self.source, group_id=group_id | ||||
|             ) | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         group = matching_groups.first() | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_LINK, | ||||
|         ]: | ||||
|             new_connection.group = group | ||||
|             return Action.LINK, new_connection | ||||
|         if self.source.group_matching_mode in [ | ||||
|             SourceGroupMatchingModes.NAME_DENY, | ||||
|         ]: | ||||
|             LOGGER.info( | ||||
|                 "denying source because group exists", | ||||
|                 source=self.source, | ||||
|                 group=group, | ||||
|                 group_id=group_id, | ||||
|             ) | ||||
|             return Action.DENY, None | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|  | ||||
|     def handle_group( | ||||
|         self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> Group | None: | ||||
|         action, connection = self.get_action(group_id, group_properties) | ||||
|         action, connection = self.matcher.get_group_action(group_id, group_properties) | ||||
|         if action == Action.ENROLL: | ||||
|             group = Group.objects.create(**group_properties) | ||||
|             connection.group = group | ||||
|             connection.save() | ||||
|             return group | ||||
|         elif action == Action.LINK: | ||||
|         elif action in (Action.LINK, Action.AUTH): | ||||
|             group = connection.group | ||||
|             group.update_attributes(group_properties) | ||||
|             connection.save() | ||||
| @ -489,6 +375,7 @@ class GroupUpdateStage(StageView): | ||||
|         self.group_connection_type: GroupSourceConnection = ( | ||||
|             self.executor.current_stage.group_connection_type | ||||
|         ) | ||||
|         self.matcher = SourceMatcher(self.source, None, self.group_connection_type) | ||||
|  | ||||
|         raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ | ||||
|             PLAN_CONTEXT_SOURCE_GROUPS | ||||
|  | ||||
							
								
								
									
										152
									
								
								authentik/core/sources/matcher.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								authentik/core/sources/matcher.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,152 @@ | ||||
| """Source user and group matching""" | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from enum import Enum | ||||
| from typing import Any | ||||
|  | ||||
| from django.db.models import Q | ||||
| from structlog import get_logger | ||||
|  | ||||
| from authentik.core.models import ( | ||||
|     Group, | ||||
|     GroupSourceConnection, | ||||
|     Source, | ||||
|     SourceGroupMatchingModes, | ||||
|     SourceUserMatchingModes, | ||||
|     User, | ||||
|     UserSourceConnection, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class Action(Enum): | ||||
|     """Actions that can be decided based on the request and source settings""" | ||||
|  | ||||
|     LINK = "link" | ||||
|     AUTH = "auth" | ||||
|     ENROLL = "enroll" | ||||
|     DENY = "deny" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class MatchableProperty: | ||||
|     property: str | ||||
|     link_mode: SourceUserMatchingModes | SourceGroupMatchingModes | ||||
|     deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes | ||||
|  | ||||
|  | ||||
| class SourceMatcher: | ||||
|     def __init__( | ||||
|         self, | ||||
|         source: Source, | ||||
|         user_connection_type: type[UserSourceConnection], | ||||
|         group_connection_type: type[GroupSourceConnection], | ||||
|     ): | ||||
|         self.source = source | ||||
|         self.user_connection_type = user_connection_type | ||||
|         self.group_connection_type = group_connection_type | ||||
|         self._logger = get_logger().bind(source=self.source) | ||||
|  | ||||
|     def get_action( | ||||
|         self, | ||||
|         object_type: type[User | Group], | ||||
|         matchable_properties: list[MatchableProperty], | ||||
|         identifier: str, | ||||
|         properties: dict[str, Any | dict[str, Any]], | ||||
|     ) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]: | ||||
|         connection_type = None | ||||
|         matching_mode = None | ||||
|         identifier_matching_mode = None | ||||
|         if object_type == User: | ||||
|             connection_type = self.user_connection_type | ||||
|             matching_mode = self.source.user_matching_mode | ||||
|             identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER | ||||
|         if object_type == Group: | ||||
|             connection_type = self.group_connection_type | ||||
|             matching_mode = self.source.group_matching_mode | ||||
|             identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER | ||||
|         if not connection_type or not matching_mode or not identifier_matching_mode: | ||||
|             return Action.DENY, None | ||||
|  | ||||
|         new_connection = connection_type(source=self.source, identifier=identifier) | ||||
|  | ||||
|         existing_connections = connection_type.objects.filter( | ||||
|             source=self.source, identifier=identifier | ||||
|         ) | ||||
|         if existing_connections.exists(): | ||||
|             return Action.AUTH, existing_connections.first() | ||||
|         # No connection exists, but we match on identifier, so enroll | ||||
|         if matching_mode == identifier_matching_mode: | ||||
|             # We don't save the connection here cause it doesn't have a user/group assigned yet | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         # Check for existing users with matching attributes | ||||
|         query = Q() | ||||
|         for matchable_property in matchable_properties: | ||||
|             property = matchable_property.property | ||||
|             if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]: | ||||
|                 if not properties.get(property, None): | ||||
|                     self._logger.warning( | ||||
|                         "Refusing to use none property", identifier=identifier, property=property | ||||
|                     ) | ||||
|                     return Action.DENY, None | ||||
|                 query_args = { | ||||
|                     f"{property}__exact": properties[property], | ||||
|                 } | ||||
|                 query = Q(**query_args) | ||||
|         self._logger.debug( | ||||
|             "Trying to link with existing object", query=query, identifier=identifier | ||||
|         ) | ||||
|         matching_objects = object_type.objects.filter(query) | ||||
|         # Not matching objects, always enroll | ||||
|         if not matching_objects.exists(): | ||||
|             self._logger.debug("No matching objects found, enrolling") | ||||
|             return Action.ENROLL, new_connection | ||||
|  | ||||
|         obj = matching_objects.first() | ||||
|         if matching_mode in [mp.link_mode for mp in matchable_properties]: | ||||
|             attr = None | ||||
|             if object_type == User: | ||||
|                 attr = "user" | ||||
|             if object_type == Group: | ||||
|                 attr = "group" | ||||
|             setattr(new_connection, attr, obj) | ||||
|             return Action.LINK, new_connection | ||||
|         if matching_mode in [mp.deny_mode for mp in matchable_properties]: | ||||
|             self._logger.info("Denying source because object exists", obj=obj) | ||||
|             return Action.DENY, None | ||||
|  | ||||
|         # Should never get here as default enroll case is returned above. | ||||
|         return Action.DENY, None  # pragma: no cover | ||||
|  | ||||
|     def get_user_action( | ||||
|         self, identifier: str, properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, UserSourceConnection | None]: | ||||
|         return self.get_action( | ||||
|             User, | ||||
|             [ | ||||
|                 MatchableProperty( | ||||
|                     "username", | ||||
|                     SourceUserMatchingModes.USERNAME_LINK, | ||||
|                     SourceUserMatchingModes.USERNAME_DENY, | ||||
|                 ), | ||||
|                 MatchableProperty( | ||||
|                     "email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY | ||||
|                 ), | ||||
|             ], | ||||
|             identifier, | ||||
|             properties, | ||||
|         ) | ||||
|  | ||||
|     def get_group_action( | ||||
|         self, identifier: str, properties: dict[str, Any | dict[str, Any]] | ||||
|     ) -> tuple[Action, GroupSourceConnection | None]: | ||||
|         return self.get_action( | ||||
|             Group, | ||||
|             [ | ||||
|                 MatchableProperty( | ||||
|                     "name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY | ||||
|                 ), | ||||
|             ], | ||||
|             identifier, | ||||
|             properties, | ||||
|         ) | ||||
		Reference in New Issue
	
	Block a user