From 338da726227f0e2b9efc4a226bb9d4c78e1b9881 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Thu, 5 Jun 2025 18:32:04 +0200 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/lib/sync/outgoing/signals.py | 60 +++++++----- authentik/lib/sync/outgoing/tasks.py | 128 +++++++++++++++---------- 2 files changed, 112 insertions(+), 76 deletions(-) diff --git a/authentik/lib/sync/outgoing/signals.py b/authentik/lib/sync/outgoing/signals.py index cd56c07b62..f6df70154a 100644 --- a/authentik/lib/sync/outgoing/signals.py +++ b/authentik/lib/sync/outgoing/signals.py @@ -20,27 +20,36 @@ def register_signals( def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): """Post save handler""" - if not provider_type.objects.filter( + for provider in provider_type.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - task_sync_direct.send(class_to_path(instance.__class__), instance.pk, Direction.add.value) + ): + task_sync_direct.send_with_options( + args=( + class_to_path(instance.__class__), + instance.pk, + provider.pk, + Direction.add.value, + ), + rel_obj=provider, + ) post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) def model_pre_delete(sender: type[Model], instance: User | Group, **_): """Pre-delete handler""" - if not provider_type.objects.filter( + for provider in provider_type.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - try: - task_sync_direct.send( - class_to_path(instance.__class__), instance.pk, Direction.remove.value - ).get_result(block=True) - except ResultFailure: - pass + ): + task_sync_direct.send_with_options( + args=( + class_to_path(instance.__class__), + instance.pk, + provider.pk, + Direction.remove.value, + ), + rel_obj=provider, + ) pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) @@ -51,16 +60,21 @@ def register_signals( """Sync group membership""" if action not in ["post_add", "post_remove"]: return - if not provider_type.objects.filter( + for provider in provider_type.objects.filter( Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ).exists(): - return - # reverse: instance is a Group, pk_set is a list of user pks - # non-reverse: instance is a User, pk_set is a list of groups - if reverse: - task_sync_m2m.send(instance.pk, action, list(pk_set)) - else: - for group_pk in pk_set: - task_sync_m2m.send(group_pk, action, [instance.pk]) + ): + # reverse: instance is a Group, pk_set is a list of user pks + # non-reverse: instance is a User, pk_set is a list of groups + if reverse: + task_sync_m2m.send_with_options( + args=(instance.pk, provider.pk, action, list(pk_set)), + rel_obj=provider, + ) + else: + for group_pk in pk_set: + task_sync_m2m.send_with_options( + args=(group_pk, provider.pk, action, [instance.pk]), + rel_obj=provider, + ) m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) diff --git a/authentik/lib/sync/outgoing/tasks.py b/authentik/lib/sync/outgoing/tasks.py index 12c683481a..49b9411b3e 100644 --- a/authentik/lib/sync/outgoing/tasks.py +++ b/authentik/lib/sync/outgoing/tasks.py @@ -122,7 +122,10 @@ class SyncTasks: object_type=object_type, ) task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}") - provider = self._provider_model.objects.filter(pk=provider_pk).first() + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() if not provider: return # Override dry run mode if requested, however don't save the provider @@ -177,7 +180,13 @@ class SyncTasks: ) break - def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): + def sync_signal_direct( + self, + model: str, + pk: str | int, + provider_pk: int, + raw_op: str, + ): self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) @@ -185,65 +194,78 @@ class SyncTasks: instance = model_class.objects.filter(pk=pk).first() if not instance: return + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() + if not provider: + return operation = Direction(raw_op) - for provider in self._provider_model.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ): - client = provider.client_for_model(instance.__class__) - # Check if the object is allowed within the provider's restrictions - queryset = provider.get_object_qs(instance.__class__) - if not queryset: - continue + client = provider.client_for_model(instance.__class__) + # Check if the object is allowed within the provider's restrictions + queryset = provider.get_object_qs(instance.__class__) + if not queryset: + return - # The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not queryset.filter(pk=instance.pk).exists(): - continue + # The queryset we get from the provider must include the instance we've got given + # otherwise ignore this provider + if not queryset.filter(pk=instance.pk).exists(): + return - try: - if operation == Direction.add: - client.write(instance) - if operation == Direction.remove: - client.delete(instance) - except TransientSyncException as exc: - raise Retry() from exc - except SkipObjectException: - continue - except DryRunRejected as exc: - self.logger.info("Rejected dry-run event", exc=exc) - except StopSync as exc: - self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) + try: + if operation == Direction.add: + client.write(instance) + if operation == Direction.remove: + client.delete(instance) + except TransientSyncException as exc: + raise Retry() from exc + except SkipObjectException: + return + except DryRunRejected as exc: + self.logger.info("Rejected dry-run event", exc=exc) + except StopSync as exc: + self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) - def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): + def sync_signal_m2m( + self, + group_pk: str, + provider_pk: int, + action: str, + pk_set: list[int], + ): self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) group = Group.objects.filter(pk=group_pk).first() if not group: return - for provider in self._provider_model.objects.filter( - Q(backchannel_application__isnull=False) | Q(application__isnull=False) - ): - # Check if the object is allowed within the provider's restrictions - queryset: QuerySet = provider.get_object_qs(Group) - # The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not queryset.filter(pk=group_pk).exists(): - continue + provider: OutgoingSyncProvider = self._provider_model.objects.filter( + Q(backchannel_application__isnull=False) | Q(application__isnull=False), + pk=provider_pk, + ).first() + if not provider: + return - client = provider.client_for_model(Group) - try: - operation = None - if action == "post_add": - operation = Direction.add - if action == "post_remove": - operation = Direction.remove - client.update_group(group, operation, pk_set) - except TransientSyncException as exc: - raise Retry() from exc - except SkipObjectException: - continue - except DryRunRejected as exc: - self.logger.info("Rejected dry-run event", exc=exc) - except StopSync as exc: - self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) + # Check if the object is allowed within the provider's restrictions + queryset: QuerySet = provider.get_object_qs(Group) + # The queryset we get from the provider must include the instance we've got given + # otherwise ignore this provider + if not queryset.filter(pk=group_pk).exists(): + return + + client = provider.client_for_model(Group) + try: + operation = None + if action == "post_add": + operation = Direction.add + if action == "post_remove": + operation = Direction.remove + client.update_group(group, operation, pk_set) + except TransientSyncException as exc: + raise Retry() from exc + except SkipObjectException: + return + except DryRunRejected as exc: + self.logger.info("Rejected dry-run event", exc=exc) + except StopSync as exc: + self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)