Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-05 18:32:04 +02:00
parent 90debcdd70
commit 338da72622
2 changed files with 112 additions and 76 deletions

View File

@ -20,27 +20,36 @@ def register_signals(
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler""" """Post save handler"""
if not provider_type.objects.filter( for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists(): ):
return task_sync_direct.send_with_options(
task_sync_direct.send(class_to_path(instance.__class__), instance.pk, Direction.add.value) args=(
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.add.value,
),
rel_obj=provider,
)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_): def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler""" """Pre-delete handler"""
if not provider_type.objects.filter( for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists(): ):
return task_sync_direct.send_with_options(
try: args=(
task_sync_direct.send( class_to_path(instance.__class__),
class_to_path(instance.__class__), instance.pk, Direction.remove.value instance.pk,
).get_result(block=True) provider.pk,
except ResultFailure: Direction.remove.value,
pass ),
rel_obj=provider,
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
@ -51,16 +60,21 @@ def register_signals(
"""Sync group membership""" """Sync group membership"""
if action not in ["post_add", "post_remove"]: if action not in ["post_add", "post_remove"]:
return return
if not provider_type.objects.filter( for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists(): ):
return
# reverse: instance is a Group, pk_set is a list of user pks # reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups # non-reverse: instance is a User, pk_set is a list of groups
if reverse: if reverse:
task_sync_m2m.send(instance.pk, action, list(pk_set)) task_sync_m2m.send_with_options(
args=(instance.pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else: else:
for group_pk in pk_set: for group_pk in pk_set:
task_sync_m2m.send(group_pk, action, [instance.pk]) task_sync_m2m.send_with_options(
args=(group_pk, provider.pk, action, [instance.pk]),
rel_obj=provider,
)
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -122,7 +122,10 @@ class SyncTasks:
object_type=object_type, object_type=object_type,
) )
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}") task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
provider = self._provider_model.objects.filter(pk=provider_pk).first() provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider: if not provider:
return return
# Override dry run mode if requested, however don't save the provider # Override dry run mode if requested, however don't save the provider
@ -177,7 +180,13 @@ class SyncTasks:
) )
break break
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): def sync_signal_direct(
self,
model: str,
pk: str | int,
provider_pk: int,
raw_op: str,
):
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )
@ -185,20 +194,23 @@ class SyncTasks:
instance = model_class.objects.filter(pk=pk).first() instance = model_class.objects.filter(pk=pk).first()
if not instance: if not instance:
return return
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
operation = Direction(raw_op) operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
client = provider.client_for_model(instance.__class__) client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions # Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__) queryset = provider.get_object_qs(instance.__class__)
if not queryset: if not queryset:
continue return
# The queryset we get from the provider must include the instance we've got given # The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider # otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists(): if not queryset.filter(pk=instance.pk).exists():
continue return
try: try:
if operation == Direction.add: if operation == Direction.add:
@ -208,28 +220,38 @@ class SyncTasks:
except TransientSyncException as exc: except TransientSyncException as exc:
raise Retry() from exc raise Retry() from exc
except SkipObjectException: except SkipObjectException:
continue return
except DryRunRejected as exc: except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc) self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): def sync_signal_m2m(
self,
group_pk: str,
provider_pk: int,
action: str,
pk_set: list[int],
):
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )
group = Group.objects.filter(pk=group_pk).first() group = Group.objects.filter(pk=group_pk).first()
if not group: if not group:
return return
for provider in self._provider_model.objects.filter( provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) Q(backchannel_application__isnull=False) | Q(application__isnull=False),
): pk=provider_pk,
).first()
if not provider:
return
# Check if the object is allowed within the provider's restrictions # Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group) queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given # The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider # otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists(): if not queryset.filter(pk=group_pk).exists():
continue return
client = provider.client_for_model(Group) client = provider.client_for_model(Group)
try: try:
@ -242,7 +264,7 @@ class SyncTasks:
except TransientSyncException as exc: except TransientSyncException as exc:
raise Retry() from exc raise Retry() from exc
except SkipObjectException: except SkipObjectException:
continue return
except DryRunRejected as exc: except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc) self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc: except StopSync as exc: