outgoing sync dispatch tasks (no magic)

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-25 14:54:31 +02:00
parent 97a5acdff5
commit 13b5aa604b
9 changed files with 123 additions and 62 deletions

View File

@ -2,13 +2,13 @@
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import ( from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync_direct, google_workspace_sync_direct_dispatch,
google_workspace_sync_m2m, google_workspace_sync_m2m_dispatch,
) )
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
register_signals( register_signals(
GoogleWorkspaceProvider, GoogleWorkspaceProvider,
task_sync_direct=google_workspace_sync_direct, task_sync_direct_dispatch=google_workspace_sync_direct_dispatch,
task_sync_m2m=google_workspace_sync_m2m, task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch,
) )

View File

@ -25,6 +25,24 @@ def google_workspace_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a direct object (user, group) for Google Workspace providers."
)
)
def google_workspace_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Google Workspace provider.")) @actor(description=_("Sync a related object (memberships) for Google Workspace provider."))
def google_workspace_sync_m2m(*args, **kwargs): def google_workspace_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Google Workspace providers."
)
)
def google_workspace_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs)

View File

@ -2,13 +2,13 @@
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import ( from authentik.enterprise.providers.microsoft_entra.tasks import (
microsoft_entra_sync_direct, microsoft_entra_sync_direct_dispatch,
microsoft_entra_sync_m2m, microsoft_entra_sync_m2m_dispatch,
) )
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
register_signals( register_signals(
MicrosoftEntraProvider, MicrosoftEntraProvider,
task_sync_direct=microsoft_entra_sync_direct, task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch,
task_sync_m2m=microsoft_entra_sync_m2m, task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch,
) )

View File

@ -25,6 +25,22 @@ def microsoft_entra_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(
description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.")
)
def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider.")) @actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
def microsoft_entra_sync_m2m(*args, **kwargs): def microsoft_entra_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Microsoft Entra providers."
)
)
def microsoft_entra_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs)

View File

@ -7,7 +7,7 @@ from rest_framework.response import Response
from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.logs import LogEvent, LogEventSerializer from authentik.events.logs import LogEventSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.rbac.filters import ObjectFilter from authentik.rbac.filters import ObjectFilter
@ -36,7 +36,7 @@ class SyncObjectResultSerializer(PassiveSerializer):
class OutgoingSyncProviderStatusMixin: class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers""" """Common API Endpoints for Outgoing sync providers"""
sync_objects_task: type[Actor] = None sync_objects_task: Actor
@extend_schema( @extend_schema(
request=SyncObjectSerializer, request=SyncObjectSerializer,
@ -55,12 +55,12 @@ class OutgoingSyncProviderStatusMixin:
params = SyncObjectSerializer(data=request.data) params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True) params.is_valid(raise_exception=True)
msg = self.sync_objects_task.send_with_options( msg = self.sync_objects_task.send_with_options(
args=(params.validated_data["sync_object_model"],),
kwargs={ kwargs={
"object_type": params.validated_data["sync_object_model"],
"page": 1, "page": 1,
"provider_pk": provider.pk, "provider_pk": provider.pk,
"pk": params.validated_data["sync_object_id"],
"override_dry_run": params.validated_data["override_dry_run"], "override_dry_run": params.validated_data["override_dry_run"],
"pk": params.validated_data["sync_object_id"],
}, },
rel_obj=provider, rel_obj=provider,
) )

View File

@ -11,44 +11,30 @@ from authentik.lib.utils.reflection import class_to_path
def register_signals( def register_signals(
provider_type: type[OutgoingSyncProvider], provider_type: type[OutgoingSyncProvider],
task_sync_direct: Actor, task_sync_direct_dispatch: Actor[[str, str | int, str], None],
task_sync_m2m: Actor, task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None],
): ):
"""Register sync signals""" """Register sync signals"""
uid = class_to_path(provider_type) uid = class_to_path(provider_type)
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler""" """Post save handler"""
for provider in provider_type.objects.filter( task_sync_direct_dispatch.send(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) class_to_path(instance.__class__),
): instance.pk,
task_sync_direct.send_with_options( Direction.add.value,
args=( )
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.add.value,
),
rel_obj=provider,
)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_): def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler""" """Pre-delete handler"""
for provider in provider_type.objects.filter( task_sync_direct_dispatch.send(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) class_to_path(instance.__class__),
): instance.pk,
task_sync_direct.send_with_options( Direction.remove.value,
args=( )
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.remove.value,
),
rel_obj=provider,
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
@ -59,21 +45,6 @@ def register_signals(
"""Sync group membership""" """Sync group membership"""
if action not in ["post_add", "post_remove"]: if action not in ["post_add", "post_remove"]:
return return
for provider in provider_type.objects.filter( task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse)
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.send_with_options(
args=(instance.pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else:
for group_pk in pk_set:
task_sync_m2m.send_with_options(
args=(group_pk, provider.pk, action, [instance.pk]),
rel_obj=provider,
)
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -38,7 +38,7 @@ class SyncTasks:
self, self,
current_task: Task, current_task: Task,
provider: OutgoingSyncProvider, provider: OutgoingSyncProvider,
sync_objects: Actor, sync_objects: Actor[[str, int, int, bool], None],
paginator: Paginator, paginator: Paginator,
object_type: type[User | Group], object_type: type[User | Group],
**options, **options,
@ -58,7 +58,7 @@ class SyncTasks:
def sync( def sync(
self, self,
provider_pk: int, provider_pk: int,
sync_objects: Actor, sync_objects: Actor[[str, int, int, bool], None],
): ):
task: Task = CurrentTask.get_task() task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
@ -101,6 +101,7 @@ class SyncTasks:
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group))
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc) self.logger.warning("transient sync exception", exc=exc)
task.warning("Sync encountered a transient exception. Retrying", exc=exc)
raise Retry() from exc raise Retry() from exc
except StopSync as exc: except StopSync as exc:
task.error(exc) task.error(exc)
@ -162,23 +163,39 @@ class SyncTasks:
self.logger.warning("failed to sync object", exc=exc, obj=obj) self.logger.warning("failed to sync object", exc=exc, obj=obj)
task.warning( task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)}, arguments=exc.args[1:],
obj=sanitize_item(obj),
) )
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj) self.logger.warning("failed to sync object", exc=exc, user=obj)
task.warning( task.warning(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to "
"transient error: {str(exc)}", "transient error: {str(exc)}",
attributes={"obj": sanitize_item(obj)}, obj=sanitize_item(obj),
) )
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc) self.logger.warning("Stopping sync", exc=exc)
task.warning( task.warning(
f"Stopping sync due to error: {exc.detail()}", f"Stopping sync due to error: {exc.detail()}",
attributes={"obj": sanitize_item(obj)}, obj=sanitize_item(obj),
) )
break break
def sync_signal_direct_dispatch(
self,
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
model: str,
pk: str | int,
raw_op: str,
):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_signal_direct.send_with_options(
args=(model, pk, provider.pk, raw_op),
rel_obj=provider,
)
def sync_signal_direct( def sync_signal_direct(
self, self,
model: str, model: str,
@ -227,6 +244,35 @@ class SyncTasks:
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m_dispatch(
self,
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
instance_pk: str,
action: str,
pk_set: list[int],
reverse: bool,
):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_signal_m2m.send_with_options(
args=(instance_pk, provider.pk, action, pk_set),
rel_obj=provider,
)
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_signal_m2m.send_with_options(
args=(instance_pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else:
for pk in pk_set:
task_sync_signal_m2m.send_with_options(
args=(pk, provider.pk, action, [instance_pk]),
rel_obj=provider,
)
def sync_signal_m2m( def sync_signal_m2m(
self, self,
group_pk: str, group_pk: str,

View File

@ -2,10 +2,10 @@
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync_direct, scim_sync_m2m from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch
register_signals( register_signals(
SCIMProvider, SCIMProvider,
task_sync_direct=scim_sync_direct, task_sync_direct_dispatch=scim_sync_direct_dispatch,
task_sync_m2m=scim_sync_m2m, task_sync_m2m_dispatch=scim_sync_m2m_dispatch,
) )

View File

@ -25,6 +25,16 @@ def scim_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers."))
def scim_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for SCIM provider.")) @actor(description=_("Sync a related object (memberships) for SCIM provider."))
def scim_sync_m2m(*args, **kwargs): def scim_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers."))
def scim_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs)