outgoing sync dispatch tasks (no magic)

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-25 14:54:31 +02:00
parent 97a5acdff5
commit 13b5aa604b
9 changed files with 123 additions and 62 deletions

View File

@ -2,13 +2,13 @@
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync_direct,
google_workspace_sync_m2m,
google_workspace_sync_direct_dispatch,
google_workspace_sync_m2m_dispatch,
)
from authentik.lib.sync.outgoing.signals import register_signals
register_signals(
GoogleWorkspaceProvider,
task_sync_direct=google_workspace_sync_direct,
task_sync_m2m=google_workspace_sync_m2m,
task_sync_direct_dispatch=google_workspace_sync_direct_dispatch,
task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch,
)

View File

@ -25,6 +25,24 @@ def google_workspace_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a direct object (user, group) for Google Workspace providers."
)
)
def google_workspace_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Google Workspace provider."))
def google_workspace_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Google Workspace providers."
)
)
def google_workspace_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs)

View File

@ -2,13 +2,13 @@
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import (
microsoft_entra_sync_direct,
microsoft_entra_sync_m2m,
microsoft_entra_sync_direct_dispatch,
microsoft_entra_sync_m2m_dispatch,
)
from authentik.lib.sync.outgoing.signals import register_signals
register_signals(
MicrosoftEntraProvider,
task_sync_direct=microsoft_entra_sync_direct,
task_sync_m2m=microsoft_entra_sync_m2m,
task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch,
task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch,
)

View File

@ -25,6 +25,22 @@ def microsoft_entra_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(
description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.")
)
def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
def microsoft_entra_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Microsoft Entra providers."
)
)
def microsoft_entra_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs)

View File

@ -7,7 +7,7 @@ from rest_framework.response import Response
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User
from authentik.events.logs import LogEvent, LogEventSerializer
from authentik.events.logs import LogEventSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path
from authentik.rbac.filters import ObjectFilter
@ -36,7 +36,7 @@ class SyncObjectResultSerializer(PassiveSerializer):
class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers"""
sync_objects_task: type[Actor] = None
sync_objects_task: Actor
@extend_schema(
request=SyncObjectSerializer,
@ -55,12 +55,12 @@ class OutgoingSyncProviderStatusMixin:
params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True)
msg = self.sync_objects_task.send_with_options(
args=(params.validated_data["sync_object_model"],),
kwargs={
"object_type": params.validated_data["sync_object_model"],
"page": 1,
"provider_pk": provider.pk,
"pk": params.validated_data["sync_object_id"],
"override_dry_run": params.validated_data["override_dry_run"],
"pk": params.validated_data["sync_object_id"],
},
rel_obj=provider,
)

View File

@ -11,44 +11,30 @@ from authentik.lib.utils.reflection import class_to_path
def register_signals(
provider_type: type[OutgoingSyncProvider],
task_sync_direct: Actor,
task_sync_m2m: Actor,
task_sync_direct_dispatch: Actor[[str, str | int, str], None],
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
):
"""Register sync signals"""
uid = class_to_path(provider_type)
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_direct.send_with_options(
args=(
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.add.value,
),
rel_obj=provider,
)
task_sync_direct_dispatch.send(
class_to_path(instance.__class__),
instance.pk,
Direction.add.value,
)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_direct.send_with_options(
args=(
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.remove.value,
),
rel_obj=provider,
)
task_sync_direct_dispatch.send(
class_to_path(instance.__class__),
instance.pk,
Direction.remove.value,
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
@ -59,21 +45,6 @@ def register_signals(
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.send_with_options(
args=(instance.pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else:
for group_pk in pk_set:
task_sync_m2m.send_with_options(
args=(group_pk, provider.pk, action, [instance.pk]),
rel_obj=provider,
)
task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse)
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -38,7 +38,7 @@ class SyncTasks:
self,
current_task: Task,
provider: OutgoingSyncProvider,
sync_objects: Actor,
sync_objects: Actor[[str, int, int, bool], None],
paginator: Paginator,
object_type: type[User | Group],
**options,
@ -58,7 +58,7 @@ class SyncTasks:
def sync(
self,
provider_pk: int,
sync_objects: Actor,
sync_objects: Actor[[str, int, int, bool], None],
):
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind(
@ -101,6 +101,7 @@ class SyncTasks:
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc)
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
raise Retry() from exc
except StopSync as exc:
task.error(exc)
@ -162,23 +163,39 @@ class SyncTasks:
self.logger.warning("failed to sync object", exc=exc, obj=obj)
task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
arguments=exc.args[1:],
obj=sanitize_item(obj),
)
except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj)
task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to "
"transient error: {str(exc)}",
attributes={"obj": sanitize_item(obj)},
obj=sanitize_item(obj),
)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc)
task.warning(
f"Stopping sync due to error: {exc.detail()}",
attributes={"obj": sanitize_item(obj)},
obj=sanitize_item(obj),
)
break
def sync_signal_direct_dispatch(
self,
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
model: str,
pk: str | int,
raw_op: str,
):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_signal_direct.send_with_options(
args=(model, pk, provider.pk, raw_op),
rel_obj=provider,
)
def sync_signal_direct(
self,
model: str,
@ -227,6 +244,35 @@ class SyncTasks:
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m_dispatch(
self,
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
instance_pk: str,
action: str,
pk_set: list[int],
reverse: bool,
):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_signal_m2m.send_with_options(
args=(instance_pk, provider.pk, action, pk_set),
rel_obj=provider,
)
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_signal_m2m.send_with_options(
args=(instance_pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else:
for pk in pk_set:
task_sync_signal_m2m.send_with_options(
args=(pk, provider.pk, action, [instance_pk]),
rel_obj=provider,
)
def sync_signal_m2m(
self,
group_pk: str,

View File

@ -2,10 +2,10 @@
from authentik.lib.sync.outgoing.signals import register_signals
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync_direct, scim_sync_m2m
from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch
register_signals(
SCIMProvider,
task_sync_direct=scim_sync_direct,
task_sync_m2m=scim_sync_m2m,
task_sync_direct_dispatch=scim_sync_direct_dispatch,
task_sync_m2m_dispatch=scim_sync_m2m_dispatch,
)

View File

@ -25,6 +25,16 @@ def scim_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers."))
def scim_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for SCIM provider."))
def scim_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers."))
def scim_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs)