core: extract object matching from flow manager (#11458)

This commit is contained in:
Marc 'risson' Schmitt
2024-10-17 14:21:39 +02:00
committed by GitHub
parent 3262e70eac
commit b57df12ace
2 changed files with 167 additions and 128 deletions

View File

@ -1,11 +1,9 @@
"""Source decision helper""" """Source decision helper"""
from enum import Enum
from typing import Any from typing import Any
from django.contrib import messages from django.contrib import messages
from django.db import IntegrityError, transaction from django.db import IntegrityError, transaction
from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.shortcuts import redirect from django.shortcuts import redirect
from django.urls import reverse from django.urls import reverse
@ -16,12 +14,11 @@ from authentik.core.models import (
Group, Group,
GroupSourceConnection, GroupSourceConnection,
Source, Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User, User,
UserSourceConnection, UserSourceConnection,
) )
from authentik.core.sources.mapper import SourceMapper from authentik.core.sources.mapper import SourceMapper
from authentik.core.sources.matcher import Action, SourceMatcher
from authentik.core.sources.stage import ( from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION, PLAN_CONTEXT_SOURCES_CONNECTION,
PostSourceStage, PostSourceStage,
@ -54,16 +51,6 @@ SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token"
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups" PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
class Action(Enum):
"""Actions that can be decided based on the request
and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
class MessageStage(StageView): class MessageStage(StageView):
"""Show a pre-configured message after the flow is done""" """Show a pre-configured message after the flow is done"""
@ -86,6 +73,7 @@ class SourceFlowManager:
source: Source source: Source
mapper: SourceMapper mapper: SourceMapper
matcher: SourceMatcher
request: HttpRequest request: HttpRequest
identifier: str identifier: str
@ -108,6 +96,9 @@ class SourceFlowManager:
) -> None: ) -> None:
self.source = source self.source = source
self.mapper = SourceMapper(self.source) self.mapper = SourceMapper(self.source)
self.matcher = SourceMatcher(
self.source, self.user_connection_type, self.group_connection_type
)
self.request = request self.request = request
self.identifier = identifier self.identifier = identifier
self.user_info = user_info self.user_info = user_info
@ -131,66 +122,19 @@ class SourceFlowManager:
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911 def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
"""decide which action should be taken""" """decide which action should be taken"""
new_connection = self.user_connection_type(source=self.source, identifier=self.identifier)
# When request is authenticated, always link # When request is authenticated, always link
if self.request.user.is_authenticated: if self.request.user.is_authenticated:
new_connection = self.user_connection_type(
source=self.source, identifier=self.identifier
)
new_connection.user = self.request.user new_connection.user = self.request.user
new_connection = self.update_user_connection(new_connection, **kwargs) new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection return Action.LINK, new_connection
existing_connections = self.user_connection_type.objects.filter( action, connection = self.matcher.get_user_action(self.identifier, self.user_properties)
source=self.source, identifier=self.identifier if connection:
) connection = self.update_user_connection(connection, **kwargs)
if existing_connections.exists(): return action, connection
connection = existing_connections.first()
return Action.AUTH, self.update_user_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
# Check for existing users with matching attributes
query = Q()
# Either query existing user based on email or username
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY,
]:
if not self.user_properties.get("email", None):
self._logger.warning("Refusing to use none email")
return Action.DENY, None
query = Q(email__exact=self.user_properties.get("email", None))
if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.user_properties.get("username", None):
self._logger.warning("Refusing to use none username")
return Action.DENY, None
query = Q(username__exact=self.user_properties.get("username", None))
self._logger.debug("trying to link with existing user", query=query)
matching_users = User.objects.filter(query)
# No matching users, always enroll
if not matching_users.exists():
self._logger.debug("no matching users found, enrolling")
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
user = matching_users.first()
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.USERNAME_LINK,
]:
new_connection.user = user
new_connection = self.update_user_connection(new_connection, **kwargs)
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,
SourceUserMatchingModes.USERNAME_DENY,
]:
self._logger.info("denying source because user exists", user=user)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def update_user_connection( def update_user_connection(
self, connection: UserSourceConnection, **kwargs self, connection: UserSourceConnection, **kwargs
@ -408,74 +352,16 @@ class SourceFlowManager:
class GroupUpdateStage(StageView): class GroupUpdateStage(StageView):
"""Dynamically injected stage which updates the user after enrollment/authentication.""" """Dynamically injected stage which updates the user after enrollment/authentication."""
def get_action(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
"""decide which action should be taken"""
new_connection = self.group_connection_type(source=self.source, identifier=group_id)
existing_connections = self.group_connection_type.objects.filter(
source=self.source, identifier=group_id
)
if existing_connections.exists():
return Action.LINK, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, new_connection
# Check for existing groups with matching attributes
query = Q()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
SourceGroupMatchingModes.NAME_DENY,
]:
if not group_properties.get("name", None):
LOGGER.warning(
"Refusing to use none group name", source=self.source, group_id=group_id
)
return Action.DENY, None
query = Q(name__exact=group_properties.get("name"))
LOGGER.debug(
"trying to link with existing group", source=self.source, query=query, group_id=group_id
)
matching_groups = Group.objects.filter(query)
# No matching groups, always enroll
if not matching_groups.exists():
LOGGER.debug(
"no matching groups found, enrolling", source=self.source, group_id=group_id
)
return Action.ENROLL, new_connection
group = matching_groups.first()
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_LINK,
]:
new_connection.group = group
return Action.LINK, new_connection
if self.source.group_matching_mode in [
SourceGroupMatchingModes.NAME_DENY,
]:
LOGGER.info(
"denying source because group exists",
source=self.source,
group=group,
group_id=group_id,
)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def handle_group( def handle_group(
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]] self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
) -> Group | None: ) -> Group | None:
action, connection = self.get_action(group_id, group_properties) action, connection = self.matcher.get_group_action(group_id, group_properties)
if action == Action.ENROLL: if action == Action.ENROLL:
group = Group.objects.create(**group_properties) group = Group.objects.create(**group_properties)
connection.group = group connection.group = group
connection.save() connection.save()
return group return group
elif action == Action.LINK: elif action in (Action.LINK, Action.AUTH):
group = connection.group group = connection.group
group.update_attributes(group_properties) group.update_attributes(group_properties)
connection.save() connection.save()
@ -489,6 +375,7 @@ class GroupUpdateStage(StageView):
self.group_connection_type: GroupSourceConnection = ( self.group_connection_type: GroupSourceConnection = (
self.executor.current_stage.group_connection_type self.executor.current_stage.group_connection_type
) )
self.matcher = SourceMatcher(self.source, None, self.group_connection_type)
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[ raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
PLAN_CONTEXT_SOURCE_GROUPS PLAN_CONTEXT_SOURCE_GROUPS

View File

@ -0,0 +1,152 @@
"""Source user and group matching"""
from dataclasses import dataclass
from enum import Enum
from typing import Any
from django.db.models import Q
from structlog import get_logger
from authentik.core.models import (
Group,
GroupSourceConnection,
Source,
SourceGroupMatchingModes,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
class Action(Enum):
"""Actions that can be decided based on the request and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
@dataclass
class MatchableProperty:
property: str
link_mode: SourceUserMatchingModes | SourceGroupMatchingModes
deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes
class SourceMatcher:
def __init__(
self,
source: Source,
user_connection_type: type[UserSourceConnection],
group_connection_type: type[GroupSourceConnection],
):
self.source = source
self.user_connection_type = user_connection_type
self.group_connection_type = group_connection_type
self._logger = get_logger().bind(source=self.source)
def get_action(
self,
object_type: type[User | Group],
matchable_properties: list[MatchableProperty],
identifier: str,
properties: dict[str, Any | dict[str, Any]],
) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]:
connection_type = None
matching_mode = None
identifier_matching_mode = None
if object_type == User:
connection_type = self.user_connection_type
matching_mode = self.source.user_matching_mode
identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER
if object_type == Group:
connection_type = self.group_connection_type
matching_mode = self.source.group_matching_mode
identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER
if not connection_type or not matching_mode or not identifier_matching_mode:
return Action.DENY, None
new_connection = connection_type(source=self.source, identifier=identifier)
existing_connections = connection_type.objects.filter(
source=self.source, identifier=identifier
)
if existing_connections.exists():
return Action.AUTH, existing_connections.first()
# No connection exists, but we match on identifier, so enroll
if matching_mode == identifier_matching_mode:
# We don't save the connection here cause it doesn't have a user/group assigned yet
return Action.ENROLL, new_connection
# Check for existing users with matching attributes
query = Q()
for matchable_property in matchable_properties:
property = matchable_property.property
if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]:
if not properties.get(property, None):
self._logger.warning(
"Refusing to use none property", identifier=identifier, property=property
)
return Action.DENY, None
query_args = {
f"{property}__exact": properties[property],
}
query = Q(**query_args)
self._logger.debug(
"Trying to link with existing object", query=query, identifier=identifier
)
matching_objects = object_type.objects.filter(query)
# Not matching objects, always enroll
if not matching_objects.exists():
self._logger.debug("No matching objects found, enrolling")
return Action.ENROLL, new_connection
obj = matching_objects.first()
if matching_mode in [mp.link_mode for mp in matchable_properties]:
attr = None
if object_type == User:
attr = "user"
if object_type == Group:
attr = "group"
setattr(new_connection, attr, obj)
return Action.LINK, new_connection
if matching_mode in [mp.deny_mode for mp in matchable_properties]:
self._logger.info("Denying source because object exists", obj=obj)
return Action.DENY, None
# Should never get here as default enroll case is returned above.
return Action.DENY, None # pragma: no cover
def get_user_action(
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, UserSourceConnection | None]:
return self.get_action(
User,
[
MatchableProperty(
"username",
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
),
MatchableProperty(
"email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY
),
],
identifier,
properties,
)
def get_group_action(
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
) -> tuple[Action, GroupSourceConnection | None]:
return self.get_action(
Group,
[
MatchableProperty(
"name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY
),
],
identifier,
properties,
)