@ -20,27 +20,36 @@ def register_signals(
|
|||||||
|
|
||||||
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
|
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
|
||||||
"""Post save handler"""
|
"""Post save handler"""
|
||||||
if not provider_type.objects.filter(
|
for provider in provider_type.objects.filter(
|
||||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||||
).exists():
|
):
|
||||||
return
|
task_sync_direct.send_with_options(
|
||||||
task_sync_direct.send(class_to_path(instance.__class__), instance.pk, Direction.add.value)
|
args=(
|
||||||
|
class_to_path(instance.__class__),
|
||||||
|
instance.pk,
|
||||||
|
provider.pk,
|
||||||
|
Direction.add.value,
|
||||||
|
),
|
||||||
|
rel_obj=provider,
|
||||||
|
)
|
||||||
|
|
||||||
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
|
||||||
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
|
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
|
||||||
|
|
||||||
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
def model_pre_delete(sender: type[Model], instance: User | Group, **_):
|
||||||
"""Pre-delete handler"""
|
"""Pre-delete handler"""
|
||||||
if not provider_type.objects.filter(
|
for provider in provider_type.objects.filter(
|
||||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||||
).exists():
|
):
|
||||||
return
|
task_sync_direct.send_with_options(
|
||||||
try:
|
args=(
|
||||||
task_sync_direct.send(
|
class_to_path(instance.__class__),
|
||||||
class_to_path(instance.__class__), instance.pk, Direction.remove.value
|
instance.pk,
|
||||||
).get_result(block=True)
|
provider.pk,
|
||||||
except ResultFailure:
|
Direction.remove.value,
|
||||||
pass
|
),
|
||||||
|
rel_obj=provider,
|
||||||
|
)
|
||||||
|
|
||||||
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
|
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
|
||||||
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
|
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
|
||||||
@ -51,16 +60,21 @@ def register_signals(
|
|||||||
"""Sync group membership"""
|
"""Sync group membership"""
|
||||||
if action not in ["post_add", "post_remove"]:
|
if action not in ["post_add", "post_remove"]:
|
||||||
return
|
return
|
||||||
if not provider_type.objects.filter(
|
for provider in provider_type.objects.filter(
|
||||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
||||||
).exists():
|
):
|
||||||
return
|
|
||||||
# reverse: instance is a Group, pk_set is a list of user pks
|
# reverse: instance is a Group, pk_set is a list of user pks
|
||||||
# non-reverse: instance is a User, pk_set is a list of groups
|
# non-reverse: instance is a User, pk_set is a list of groups
|
||||||
if reverse:
|
if reverse:
|
||||||
task_sync_m2m.send(instance.pk, action, list(pk_set))
|
task_sync_m2m.send_with_options(
|
||||||
|
args=(instance.pk, provider.pk, action, list(pk_set)),
|
||||||
|
rel_obj=provider,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for group_pk in pk_set:
|
for group_pk in pk_set:
|
||||||
task_sync_m2m.send(group_pk, action, [instance.pk])
|
task_sync_m2m.send_with_options(
|
||||||
|
args=(group_pk, provider.pk, action, [instance.pk]),
|
||||||
|
rel_obj=provider,
|
||||||
|
)
|
||||||
|
|
||||||
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
|
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)
|
||||||
|
|||||||
@ -122,7 +122,10 @@ class SyncTasks:
|
|||||||
object_type=object_type,
|
object_type=object_type,
|
||||||
)
|
)
|
||||||
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
|
task.info(f"Syncing page {page} of {_object_type._meta.verbose_name_plural}")
|
||||||
provider = self._provider_model.objects.filter(pk=provider_pk).first()
|
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||||
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||||
|
pk=provider_pk,
|
||||||
|
).first()
|
||||||
if not provider:
|
if not provider:
|
||||||
return
|
return
|
||||||
# Override dry run mode if requested, however don't save the provider
|
# Override dry run mode if requested, however don't save the provider
|
||||||
@ -177,7 +180,13 @@ class SyncTasks:
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
|
def sync_signal_direct(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
pk: str | int,
|
||||||
|
provider_pk: int,
|
||||||
|
raw_op: str,
|
||||||
|
):
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
)
|
)
|
||||||
@ -185,20 +194,23 @@ class SyncTasks:
|
|||||||
instance = model_class.objects.filter(pk=pk).first()
|
instance = model_class.objects.filter(pk=pk).first()
|
||||||
if not instance:
|
if not instance:
|
||||||
return
|
return
|
||||||
|
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||||
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||||
|
pk=provider_pk,
|
||||||
|
).first()
|
||||||
|
if not provider:
|
||||||
|
return
|
||||||
operation = Direction(raw_op)
|
operation = Direction(raw_op)
|
||||||
for provider in self._provider_model.objects.filter(
|
|
||||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
|
||||||
):
|
|
||||||
client = provider.client_for_model(instance.__class__)
|
client = provider.client_for_model(instance.__class__)
|
||||||
# Check if the object is allowed within the provider's restrictions
|
# Check if the object is allowed within the provider's restrictions
|
||||||
queryset = provider.get_object_qs(instance.__class__)
|
queryset = provider.get_object_qs(instance.__class__)
|
||||||
if not queryset:
|
if not queryset:
|
||||||
continue
|
return
|
||||||
|
|
||||||
# The queryset we get from the provider must include the instance we've got given
|
# The queryset we get from the provider must include the instance we've got given
|
||||||
# otherwise ignore this provider
|
# otherwise ignore this provider
|
||||||
if not queryset.filter(pk=instance.pk).exists():
|
if not queryset.filter(pk=instance.pk).exists():
|
||||||
continue
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if operation == Direction.add:
|
if operation == Direction.add:
|
||||||
@ -208,28 +220,38 @@ class SyncTasks:
|
|||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
raise Retry() from exc
|
raise Retry() from exc
|
||||||
except SkipObjectException:
|
except SkipObjectException:
|
||||||
continue
|
return
|
||||||
except DryRunRejected as exc:
|
except DryRunRejected as exc:
|
||||||
self.logger.info("Rejected dry-run event", exc=exc)
|
self.logger.info("Rejected dry-run event", exc=exc)
|
||||||
except StopSync as exc:
|
except StopSync as exc:
|
||||||
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
|
||||||
|
|
||||||
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
|
def sync_signal_m2m(
|
||||||
|
self,
|
||||||
|
group_pk: str,
|
||||||
|
provider_pk: int,
|
||||||
|
action: str,
|
||||||
|
pk_set: list[int],
|
||||||
|
):
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
)
|
)
|
||||||
group = Group.objects.filter(pk=group_pk).first()
|
group = Group.objects.filter(pk=group_pk).first()
|
||||||
if not group:
|
if not group:
|
||||||
return
|
return
|
||||||
for provider in self._provider_model.objects.filter(
|
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
|
||||||
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
|
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
|
||||||
):
|
pk=provider_pk,
|
||||||
|
).first()
|
||||||
|
if not provider:
|
||||||
|
return
|
||||||
|
|
||||||
# Check if the object is allowed within the provider's restrictions
|
# Check if the object is allowed within the provider's restrictions
|
||||||
queryset: QuerySet = provider.get_object_qs(Group)
|
queryset: QuerySet = provider.get_object_qs(Group)
|
||||||
# The queryset we get from the provider must include the instance we've got given
|
# The queryset we get from the provider must include the instance we've got given
|
||||||
# otherwise ignore this provider
|
# otherwise ignore this provider
|
||||||
if not queryset.filter(pk=group_pk).exists():
|
if not queryset.filter(pk=group_pk).exists():
|
||||||
continue
|
return
|
||||||
|
|
||||||
client = provider.client_for_model(Group)
|
client = provider.client_for_model(Group)
|
||||||
try:
|
try:
|
||||||
@ -242,7 +264,7 @@ class SyncTasks:
|
|||||||
except TransientSyncException as exc:
|
except TransientSyncException as exc:
|
||||||
raise Retry() from exc
|
raise Retry() from exc
|
||||||
except SkipObjectException:
|
except SkipObjectException:
|
||||||
continue
|
return
|
||||||
except DryRunRejected as exc:
|
except DryRunRejected as exc:
|
||||||
self.logger.info("Rejected dry-run event", exc=exc)
|
self.logger.info("Rejected dry-run event", exc=exc)
|
||||||
except StopSync as exc:
|
except StopSync as exc:
|
||||||
|
|||||||
Reference in New Issue
Block a user