@ -20,27 +20,36 @@ def register_signals(
|
||||
|
||||
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
|
||||
"""Post save handler"""
|
||||
if not provider_type.objects.filter(
|
||||
for provider in provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
task_sync_direct.send(class_to_path(instance.__class__), instance.pk, Direction.add.value)
|
||||
):
|
||||
task_sync_direct.send_with_options(
|
||||
args=(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
provider.pk,
|
||||
Direction.add.value,
|
||||
),
|
||||
rel_obj=provider,
|
||||
)
|
||||
|
||||
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
||||
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
|
||||
|
||||
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
||||
"""Pre-delete handler"""
|
||||
if not provider_type.objects.filter(
|
||||
for provider in provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
try:
|
||||
task_sync_direct.send(
|
||||
class_to_path(instance.__class__), instance.pk, Direction.remove.value
|
||||
).get_result(block=True)
|
||||
except ResultFailure:
|
||||
pass
|
||||
):
|
||||
task_sync_direct.send_with_options(
|
||||
args=(
|
||||
class_to_path(instance.__class__),
|
||||
instance.pk,
|
||||
provider.pk,
|
||||
Direction.remove.value,
|
||||
),
|
||||
rel_obj=provider,
|
||||
)
|
||||
|
||||
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
|
||||
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
|
||||
@ -51,16 +60,21 @@ def register_signals(
|
||||
"""Sync group membership"""
|
||||
if action not in ["post_add", "post_remove"]:
|
||||
return
|
||||
if not provider_type.objects.filter(
|
||||
for provider in provider_type.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
).exists():
|
||||
return
|
||||
# reverse: instance is a Group, pk_set is a list of user pks
|
||||
# non-reverse: instance is a User, pk_set is a list of groups
|
||||
if reverse:
|
||||
task_sync_m2m.send(instance.pk, action, list(pk_set))
|
||||
else:
|
||||
for group_pk in pk_set:
|
||||
task_sync_m2m.send(group_pk, action, [instance.pk])
|
||||
):
|
||||
# reverse: instance is a Group, pk_set is a list of user pks
|
||||
# non-reverse: instance is a User, pk_set is a list of groups
|
||||
if reverse:
|
||||
task_sync_m2m.send_with_options(
|
||||
args=(instance.pk, provider.pk, action, list(pk_set)),
|
||||
rel_obj=provider,
|
||||
)
|
||||
else:
|
||||
for group_pk in pk_set:
|
||||
task_sync_m2m.send_with_options(
|
||||
args=(group_pk, provider.pk, action, [instance.pk]),
|
||||
rel_obj=provider,
|
||||
)
|
||||
|
||||
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
|
||||
|
||||
@ -122,7 +122,10 @@ class SyncTasks:
|
||||
object_type=object_type,
|
||||
)
|
||||
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
|
||||
provider = self._provider_model.objects.filter(pk=provider_pk).first()
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
# Override dry run mode if requested, however don't save the provider
|
||||
@ -177,7 +180,13 @@ class SyncTasks:
|
||||
)
|
||||
break
|
||||
|
||||
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
|
||||
def sync_signal_direct(
|
||||
self,
|
||||
model: str,
|
||||
pk: str | int,
|
||||
provider_pk: int,
|
||||
raw_op: str,
|
||||
):
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
@ -185,65 +194,78 @@ class SyncTasks:
|
||||
instance = model_class.objects.filter(pk=pk).first()
|
||||
if not instance:
|
||||
return
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
operation = Direction(raw_op)
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
client = provider.client_for_model(instance.__class__)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset = provider.get_object_qs(instance.__class__)
|
||||
if not queryset:
|
||||
continue
|
||||
client = provider.client_for_model(instance.__class__)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset = provider.get_object_qs(instance.__class__)
|
||||
if not queryset:
|
||||
return
|
||||
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=instance.pk).exists():
|
||||
continue
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=instance.pk).exists():
|
||||
return
|
||||
|
||||
try:
|
||||
if operation == Direction.add:
|
||||
client.write(instance)
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
try:
|
||||
if operation == Direction.add:
|
||||
client.write(instance)
|
||||
if operation == Direction.remove:
|
||||
client.delete(instance)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
||||
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
|
||||
def sync_signal_m2m(
|
||||
self,
|
||||
group_pk: str,
|
||||
provider_pk: int,
|
||||
action: str,
|
||||
pk_set: list[int],
|
||||
):
|
||||
self.logger = get_logger().bind(
|
||||
provider_type=class_to_path(self._provider_model),
|
||||
)
|
||||
group = Group.objects.filter(pk=group_pk).first()
|
||||
if not group:
|
||||
return
|
||||
for provider in self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||
):
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_object_qs(Group)
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
continue
|
||||
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||
pk=provider_pk,
|
||||
).first()
|
||||
if not provider:
|
||||
return
|
||||
|
||||
client = provider.client_for_model(Group)
|
||||
try:
|
||||
operation = None
|
||||
if action == "post_add":
|
||||
operation = Direction.add
|
||||
if action == "post_remove":
|
||||
operation = Direction.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
continue
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
# Check if the object is allowed within the provider's restrictions
|
||||
queryset: QuerySet = provider.get_object_qs(Group)
|
||||
# The queryset we get from the provider must include the instance we've got given
|
||||
# otherwise ignore this provider
|
||||
if not queryset.filter(pk=group_pk).exists():
|
||||
return
|
||||
|
||||
client = provider.client_for_model(Group)
|
||||
try:
|
||||
operation = None
|
||||
if action == "post_add":
|
||||
operation = Direction.add
|
||||
if action == "post_remove":
|
||||
operation = Direction.remove
|
||||
client.update_group(group, operation, pk_set)
|
||||
except TransientSyncException as exc:
|
||||
raise Retry() from exc
|
||||
except SkipObjectException:
|
||||
return
|
||||
except DryRunRejected as exc:
|
||||
self.logger.info("Rejected dry-run event", exc=exc)
|
||||
except StopSync as exc:
|
||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||
|
||||
Reference in New Issue
Block a user