Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-05 18:32:04 +02:00
parent 90debcdd70
commit 338da72622
2 changed files with 112 additions and 76 deletions

View File

@ -20,27 +20,36 @@ def register_signals(
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not provider_type.objects.filter(
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
task_sync_direct.send(class_to_path(instance.__class__), instance.pk, Direction.add.value)
):
task_sync_direct.send_with_options(
args=(
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.add.value,
),
rel_obj=provider,
)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not provider_type.objects.filter(
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
try:
task_sync_direct.send(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
).get_result(block=True)
except ResultFailure:
pass
):
task_sync_direct.send_with_options(
args=(
class_to_path(instance.__class__),
instance.pk,
provider.pk,
Direction.remove.value,
),
rel_obj=provider,
)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
@ -51,16 +60,21 @@ def register_signals(
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not provider_type.objects.filter(
for provider in provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.send(instance.pk, action, list(pk_set))
else:
for group_pk in pk_set:
task_sync_m2m.send(group_pk, action, [instance.pk])
):
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.send_with_options(
args=(instance.pk, provider.pk, action, list(pk_set)),
rel_obj=provider,
)
else:
for group_pk in pk_set:
task_sync_m2m.send_with_options(
args=(group_pk, provider.pk, action, [instance.pk]),
rel_obj=provider,
)
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -122,7 +122,10 @@ class SyncTasks:
object_type=object_type,
)
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
provider = self._provider_model.objects.filter(pk=provider_pk).first()
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
# Override dry run mode if requested, however don't save the provider
@ -177,7 +180,13 @@ class SyncTasks:
)
break
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
def sync_signal_direct(
self,
model: str,
pk: str | int,
provider_pk: int,
raw_op: str,
):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
@ -185,65 +194,78 @@ class SyncTasks:
instance = model_class.objects.filter(pk=pk).first()
if not instance:
return
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
continue
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
return
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
continue
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
return
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
continue
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
return
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
def sync_signal_m2m(
self,
group_pk: str,
provider_pk: int,
action: str,
pk_set: list[int],
):
self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model),
)
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
continue
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
client = provider.client_for_model(Group)
try:
operation = None
if action == "post_add":
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
continue
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=group_pk).exists():
return
client = provider.client_for_model(Group)
try:
operation = None
if action == "post_add":
operation = Direction.add
if action == "post_remove":
operation = Direction.remove
client.update_group(group, operation, pk_set)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
return
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)