core: extract object matching from flow manager (#11458)
This commit is contained in:
committed by
GitHub
parent
3262e70eac
commit
b57df12ace
@ -1,11 +1,9 @@
|
|||||||
"""Source decision helper"""
|
"""Source decision helper"""
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.db import IntegrityError, transaction
|
from django.db import IntegrityError, transaction
|
||||||
from django.db.models.query_utils import Q
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.shortcuts import redirect
|
from django.shortcuts import redirect
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
@ -16,12 +14,11 @@ from authentik.core.models import (
|
|||||||
Group,
|
Group,
|
||||||
GroupSourceConnection,
|
GroupSourceConnection,
|
||||||
Source,
|
Source,
|
||||||
SourceGroupMatchingModes,
|
|
||||||
SourceUserMatchingModes,
|
|
||||||
User,
|
User,
|
||||||
UserSourceConnection,
|
UserSourceConnection,
|
||||||
)
|
)
|
||||||
from authentik.core.sources.mapper import SourceMapper
|
from authentik.core.sources.mapper import SourceMapper
|
||||||
|
from authentik.core.sources.matcher import Action, SourceMatcher
|
||||||
from authentik.core.sources.stage import (
|
from authentik.core.sources.stage import (
|
||||||
PLAN_CONTEXT_SOURCES_CONNECTION,
|
PLAN_CONTEXT_SOURCES_CONNECTION,
|
||||||
PostSourceStage,
|
PostSourceStage,
|
||||||
@ -54,16 +51,6 @@ SESSION_KEY_OVERRIDE_FLOW_TOKEN = "authentik/flows/source_override_flow_token"
|
|||||||
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
|
PLAN_CONTEXT_SOURCE_GROUPS = "source_groups"
|
||||||
|
|
||||||
|
|
||||||
class Action(Enum):
|
|
||||||
"""Actions that can be decided based on the request
|
|
||||||
and source settings"""
|
|
||||||
|
|
||||||
LINK = "link"
|
|
||||||
AUTH = "auth"
|
|
||||||
ENROLL = "enroll"
|
|
||||||
DENY = "deny"
|
|
||||||
|
|
||||||
|
|
||||||
class MessageStage(StageView):
|
class MessageStage(StageView):
|
||||||
"""Show a pre-configured message after the flow is done"""
|
"""Show a pre-configured message after the flow is done"""
|
||||||
|
|
||||||
@ -86,6 +73,7 @@ class SourceFlowManager:
|
|||||||
|
|
||||||
source: Source
|
source: Source
|
||||||
mapper: SourceMapper
|
mapper: SourceMapper
|
||||||
|
matcher: SourceMatcher
|
||||||
request: HttpRequest
|
request: HttpRequest
|
||||||
|
|
||||||
identifier: str
|
identifier: str
|
||||||
@ -108,6 +96,9 @@ class SourceFlowManager:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.source = source
|
self.source = source
|
||||||
self.mapper = SourceMapper(self.source)
|
self.mapper = SourceMapper(self.source)
|
||||||
|
self.matcher = SourceMatcher(
|
||||||
|
self.source, self.user_connection_type, self.group_connection_type
|
||||||
|
)
|
||||||
self.request = request
|
self.request = request
|
||||||
self.identifier = identifier
|
self.identifier = identifier
|
||||||
self.user_info = user_info
|
self.user_info = user_info
|
||||||
@ -131,66 +122,19 @@ class SourceFlowManager:
|
|||||||
|
|
||||||
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
|
def get_action(self, **kwargs) -> tuple[Action, UserSourceConnection | None]: # noqa: PLR0911
|
||||||
"""decide which action should be taken"""
|
"""decide which action should be taken"""
|
||||||
new_connection = self.user_connection_type(source=self.source, identifier=self.identifier)
|
|
||||||
# When request is authenticated, always link
|
# When request is authenticated, always link
|
||||||
if self.request.user.is_authenticated:
|
if self.request.user.is_authenticated:
|
||||||
|
new_connection = self.user_connection_type(
|
||||||
|
source=self.source, identifier=self.identifier
|
||||||
|
)
|
||||||
new_connection.user = self.request.user
|
new_connection.user = self.request.user
|
||||||
new_connection = self.update_user_connection(new_connection, **kwargs)
|
new_connection = self.update_user_connection(new_connection, **kwargs)
|
||||||
return Action.LINK, new_connection
|
return Action.LINK, new_connection
|
||||||
|
|
||||||
existing_connections = self.user_connection_type.objects.filter(
|
action, connection = self.matcher.get_user_action(self.identifier, self.user_properties)
|
||||||
source=self.source, identifier=self.identifier
|
if connection:
|
||||||
)
|
connection = self.update_user_connection(connection, **kwargs)
|
||||||
if existing_connections.exists():
|
return action, connection
|
||||||
connection = existing_connections.first()
|
|
||||||
return Action.AUTH, self.update_user_connection(connection, **kwargs)
|
|
||||||
# No connection exists, but we match on identifier, so enroll
|
|
||||||
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
|
|
||||||
# We don't save the connection here cause it doesn't have a user assigned yet
|
|
||||||
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
|
|
||||||
|
|
||||||
# Check for existing users with matching attributes
|
|
||||||
query = Q()
|
|
||||||
# Either query existing user based on email or username
|
|
||||||
if self.source.user_matching_mode in [
|
|
||||||
SourceUserMatchingModes.EMAIL_LINK,
|
|
||||||
SourceUserMatchingModes.EMAIL_DENY,
|
|
||||||
]:
|
|
||||||
if not self.user_properties.get("email", None):
|
|
||||||
self._logger.warning("Refusing to use none email")
|
|
||||||
return Action.DENY, None
|
|
||||||
query = Q(email__exact=self.user_properties.get("email", None))
|
|
||||||
if self.source.user_matching_mode in [
|
|
||||||
SourceUserMatchingModes.USERNAME_LINK,
|
|
||||||
SourceUserMatchingModes.USERNAME_DENY,
|
|
||||||
]:
|
|
||||||
if not self.user_properties.get("username", None):
|
|
||||||
self._logger.warning("Refusing to use none username")
|
|
||||||
return Action.DENY, None
|
|
||||||
query = Q(username__exact=self.user_properties.get("username", None))
|
|
||||||
self._logger.debug("trying to link with existing user", query=query)
|
|
||||||
matching_users = User.objects.filter(query)
|
|
||||||
# No matching users, always enroll
|
|
||||||
if not matching_users.exists():
|
|
||||||
self._logger.debug("no matching users found, enrolling")
|
|
||||||
return Action.ENROLL, self.update_user_connection(new_connection, **kwargs)
|
|
||||||
|
|
||||||
user = matching_users.first()
|
|
||||||
if self.source.user_matching_mode in [
|
|
||||||
SourceUserMatchingModes.EMAIL_LINK,
|
|
||||||
SourceUserMatchingModes.USERNAME_LINK,
|
|
||||||
]:
|
|
||||||
new_connection.user = user
|
|
||||||
new_connection = self.update_user_connection(new_connection, **kwargs)
|
|
||||||
return Action.LINK, new_connection
|
|
||||||
if self.source.user_matching_mode in [
|
|
||||||
SourceUserMatchingModes.EMAIL_DENY,
|
|
||||||
SourceUserMatchingModes.USERNAME_DENY,
|
|
||||||
]:
|
|
||||||
self._logger.info("denying source because user exists", user=user)
|
|
||||||
return Action.DENY, None
|
|
||||||
# Should never get here as default enroll case is returned above.
|
|
||||||
return Action.DENY, None # pragma: no cover
|
|
||||||
|
|
||||||
def update_user_connection(
|
def update_user_connection(
|
||||||
self, connection: UserSourceConnection, **kwargs
|
self, connection: UserSourceConnection, **kwargs
|
||||||
@ -408,74 +352,16 @@ class SourceFlowManager:
|
|||||||
class GroupUpdateStage(StageView):
|
class GroupUpdateStage(StageView):
|
||||||
"""Dynamically injected stage which updates the user after enrollment/authentication."""
|
"""Dynamically injected stage which updates the user after enrollment/authentication."""
|
||||||
|
|
||||||
def get_action(
|
|
||||||
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
|
|
||||||
) -> tuple[Action, GroupSourceConnection | None]:
|
|
||||||
"""decide which action should be taken"""
|
|
||||||
new_connection = self.group_connection_type(source=self.source, identifier=group_id)
|
|
||||||
|
|
||||||
existing_connections = self.group_connection_type.objects.filter(
|
|
||||||
source=self.source, identifier=group_id
|
|
||||||
)
|
|
||||||
if existing_connections.exists():
|
|
||||||
return Action.LINK, existing_connections.first()
|
|
||||||
# No connection exists, but we match on identifier, so enroll
|
|
||||||
if self.source.group_matching_mode == SourceGroupMatchingModes.IDENTIFIER:
|
|
||||||
# We don't save the connection here cause it doesn't have a user assigned yet
|
|
||||||
return Action.ENROLL, new_connection
|
|
||||||
|
|
||||||
# Check for existing groups with matching attributes
|
|
||||||
query = Q()
|
|
||||||
if self.source.group_matching_mode in [
|
|
||||||
SourceGroupMatchingModes.NAME_LINK,
|
|
||||||
SourceGroupMatchingModes.NAME_DENY,
|
|
||||||
]:
|
|
||||||
if not group_properties.get("name", None):
|
|
||||||
LOGGER.warning(
|
|
||||||
"Refusing to use none group name", source=self.source, group_id=group_id
|
|
||||||
)
|
|
||||||
return Action.DENY, None
|
|
||||||
query = Q(name__exact=group_properties.get("name"))
|
|
||||||
LOGGER.debug(
|
|
||||||
"trying to link with existing group", source=self.source, query=query, group_id=group_id
|
|
||||||
)
|
|
||||||
matching_groups = Group.objects.filter(query)
|
|
||||||
# No matching groups, always enroll
|
|
||||||
if not matching_groups.exists():
|
|
||||||
LOGGER.debug(
|
|
||||||
"no matching groups found, enrolling", source=self.source, group_id=group_id
|
|
||||||
)
|
|
||||||
return Action.ENROLL, new_connection
|
|
||||||
|
|
||||||
group = matching_groups.first()
|
|
||||||
if self.source.group_matching_mode in [
|
|
||||||
SourceGroupMatchingModes.NAME_LINK,
|
|
||||||
]:
|
|
||||||
new_connection.group = group
|
|
||||||
return Action.LINK, new_connection
|
|
||||||
if self.source.group_matching_mode in [
|
|
||||||
SourceGroupMatchingModes.NAME_DENY,
|
|
||||||
]:
|
|
||||||
LOGGER.info(
|
|
||||||
"denying source because group exists",
|
|
||||||
source=self.source,
|
|
||||||
group=group,
|
|
||||||
group_id=group_id,
|
|
||||||
)
|
|
||||||
return Action.DENY, None
|
|
||||||
# Should never get here as default enroll case is returned above.
|
|
||||||
return Action.DENY, None # pragma: no cover
|
|
||||||
|
|
||||||
def handle_group(
|
def handle_group(
|
||||||
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
|
self, group_id: str, group_properties: dict[str, Any | dict[str, Any]]
|
||||||
) -> Group | None:
|
) -> Group | None:
|
||||||
action, connection = self.get_action(group_id, group_properties)
|
action, connection = self.matcher.get_group_action(group_id, group_properties)
|
||||||
if action == Action.ENROLL:
|
if action == Action.ENROLL:
|
||||||
group = Group.objects.create(**group_properties)
|
group = Group.objects.create(**group_properties)
|
||||||
connection.group = group
|
connection.group = group
|
||||||
connection.save()
|
connection.save()
|
||||||
return group
|
return group
|
||||||
elif action == Action.LINK:
|
elif action in (Action.LINK, Action.AUTH):
|
||||||
group = connection.group
|
group = connection.group
|
||||||
group.update_attributes(group_properties)
|
group.update_attributes(group_properties)
|
||||||
connection.save()
|
connection.save()
|
||||||
@ -489,6 +375,7 @@ class GroupUpdateStage(StageView):
|
|||||||
self.group_connection_type: GroupSourceConnection = (
|
self.group_connection_type: GroupSourceConnection = (
|
||||||
self.executor.current_stage.group_connection_type
|
self.executor.current_stage.group_connection_type
|
||||||
)
|
)
|
||||||
|
self.matcher = SourceMatcher(self.source, None, self.group_connection_type)
|
||||||
|
|
||||||
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
|
raw_groups: dict[str, dict[str, Any | dict[str, Any]]] = self.executor.plan.context[
|
||||||
PLAN_CONTEXT_SOURCE_GROUPS
|
PLAN_CONTEXT_SOURCE_GROUPS
|
||||||
|
|||||||
152
authentik/core/sources/matcher.py
Normal file
152
authentik/core/sources/matcher.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""Source user and group matching"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from django.db.models import Q
|
||||||
|
from structlog import get_logger
|
||||||
|
|
||||||
|
from authentik.core.models import (
|
||||||
|
Group,
|
||||||
|
GroupSourceConnection,
|
||||||
|
Source,
|
||||||
|
SourceGroupMatchingModes,
|
||||||
|
SourceUserMatchingModes,
|
||||||
|
User,
|
||||||
|
UserSourceConnection,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Action(Enum):
|
||||||
|
"""Actions that can be decided based on the request and source settings"""
|
||||||
|
|
||||||
|
LINK = "link"
|
||||||
|
AUTH = "auth"
|
||||||
|
ENROLL = "enroll"
|
||||||
|
DENY = "deny"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MatchableProperty:
|
||||||
|
property: str
|
||||||
|
link_mode: SourceUserMatchingModes | SourceGroupMatchingModes
|
||||||
|
deny_mode: SourceUserMatchingModes | SourceGroupMatchingModes
|
||||||
|
|
||||||
|
|
||||||
|
class SourceMatcher:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source: Source,
|
||||||
|
user_connection_type: type[UserSourceConnection],
|
||||||
|
group_connection_type: type[GroupSourceConnection],
|
||||||
|
):
|
||||||
|
self.source = source
|
||||||
|
self.user_connection_type = user_connection_type
|
||||||
|
self.group_connection_type = group_connection_type
|
||||||
|
self._logger = get_logger().bind(source=self.source)
|
||||||
|
|
||||||
|
def get_action(
|
||||||
|
self,
|
||||||
|
object_type: type[User | Group],
|
||||||
|
matchable_properties: list[MatchableProperty],
|
||||||
|
identifier: str,
|
||||||
|
properties: dict[str, Any | dict[str, Any]],
|
||||||
|
) -> tuple[Action, UserSourceConnection | GroupSourceConnection | None]:
|
||||||
|
connection_type = None
|
||||||
|
matching_mode = None
|
||||||
|
identifier_matching_mode = None
|
||||||
|
if object_type == User:
|
||||||
|
connection_type = self.user_connection_type
|
||||||
|
matching_mode = self.source.user_matching_mode
|
||||||
|
identifier_matching_mode = SourceUserMatchingModes.IDENTIFIER
|
||||||
|
if object_type == Group:
|
||||||
|
connection_type = self.group_connection_type
|
||||||
|
matching_mode = self.source.group_matching_mode
|
||||||
|
identifier_matching_mode = SourceGroupMatchingModes.IDENTIFIER
|
||||||
|
if not connection_type or not matching_mode or not identifier_matching_mode:
|
||||||
|
return Action.DENY, None
|
||||||
|
|
||||||
|
new_connection = connection_type(source=self.source, identifier=identifier)
|
||||||
|
|
||||||
|
existing_connections = connection_type.objects.filter(
|
||||||
|
source=self.source, identifier=identifier
|
||||||
|
)
|
||||||
|
if existing_connections.exists():
|
||||||
|
return Action.AUTH, existing_connections.first()
|
||||||
|
# No connection exists, but we match on identifier, so enroll
|
||||||
|
if matching_mode == identifier_matching_mode:
|
||||||
|
# We don't save the connection here cause it doesn't have a user/group assigned yet
|
||||||
|
return Action.ENROLL, new_connection
|
||||||
|
|
||||||
|
# Check for existing users with matching attributes
|
||||||
|
query = Q()
|
||||||
|
for matchable_property in matchable_properties:
|
||||||
|
property = matchable_property.property
|
||||||
|
if matching_mode in [matchable_property.link_mode, matchable_property.deny_mode]:
|
||||||
|
if not properties.get(property, None):
|
||||||
|
self._logger.warning(
|
||||||
|
"Refusing to use none property", identifier=identifier, property=property
|
||||||
|
)
|
||||||
|
return Action.DENY, None
|
||||||
|
query_args = {
|
||||||
|
f"{property}__exact": properties[property],
|
||||||
|
}
|
||||||
|
query = Q(**query_args)
|
||||||
|
self._logger.debug(
|
||||||
|
"Trying to link with existing object", query=query, identifier=identifier
|
||||||
|
)
|
||||||
|
matching_objects = object_type.objects.filter(query)
|
||||||
|
# Not matching objects, always enroll
|
||||||
|
if not matching_objects.exists():
|
||||||
|
self._logger.debug("No matching objects found, enrolling")
|
||||||
|
return Action.ENROLL, new_connection
|
||||||
|
|
||||||
|
obj = matching_objects.first()
|
||||||
|
if matching_mode in [mp.link_mode for mp in matchable_properties]:
|
||||||
|
attr = None
|
||||||
|
if object_type == User:
|
||||||
|
attr = "user"
|
||||||
|
if object_type == Group:
|
||||||
|
attr = "group"
|
||||||
|
setattr(new_connection, attr, obj)
|
||||||
|
return Action.LINK, new_connection
|
||||||
|
if matching_mode in [mp.deny_mode for mp in matchable_properties]:
|
||||||
|
self._logger.info("Denying source because object exists", obj=obj)
|
||||||
|
return Action.DENY, None
|
||||||
|
|
||||||
|
# Should never get here as default enroll case is returned above.
|
||||||
|
return Action.DENY, None # pragma: no cover
|
||||||
|
|
||||||
|
def get_user_action(
|
||||||
|
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
|
||||||
|
) -> tuple[Action, UserSourceConnection | None]:
|
||||||
|
return self.get_action(
|
||||||
|
User,
|
||||||
|
[
|
||||||
|
MatchableProperty(
|
||||||
|
"username",
|
||||||
|
SourceUserMatchingModes.USERNAME_LINK,
|
||||||
|
SourceUserMatchingModes.USERNAME_DENY,
|
||||||
|
),
|
||||||
|
MatchableProperty(
|
||||||
|
"email", SourceUserMatchingModes.EMAIL_LINK, SourceUserMatchingModes.EMAIL_DENY
|
||||||
|
),
|
||||||
|
],
|
||||||
|
identifier,
|
||||||
|
properties,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_group_action(
|
||||||
|
self, identifier: str, properties: dict[str, Any | dict[str, Any]]
|
||||||
|
) -> tuple[Action, GroupSourceConnection | None]:
|
||||||
|
return self.get_action(
|
||||||
|
Group,
|
||||||
|
[
|
||||||
|
MatchableProperty(
|
||||||
|
"name", SourceGroupMatchingModes.NAME_LINK, SourceGroupMatchingModes.NAME_DENY
|
||||||
|
),
|
||||||
|
],
|
||||||
|
identifier,
|
||||||
|
properties,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user