From 79df24f4ebdbac8b5eee722432b75bdfb193046c Mon Sep 17 00:00:00 2001 From: Jens L Date: Fri, 3 May 2024 11:48:34 +0200 Subject: [PATCH] providers/scim: fix time_limit not set correctly (#9546) Signed-off-by: Jens Langhammer --- .../scim/management/commands/scim_sync.py | 4 ++-- authentik/providers/scim/signals.py | 4 ++-- authentik/providers/scim/tasks.py | 20 +++++++++++++++++-- .../providers/scim/tests/test_membership.py | 6 +++--- authentik/providers/scim/tests/test_user.py | 4 ++-- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/authentik/providers/scim/management/commands/scim_sync.py b/authentik/providers/scim/management/commands/scim_sync.py index 25924c2a08..a82b1c0877 100644 --- a/authentik/providers/scim/management/commands/scim_sync.py +++ b/authentik/providers/scim/management/commands/scim_sync.py @@ -3,7 +3,7 @@ from structlog.stdlib import get_logger from authentik.providers.scim.models import SCIMProvider -from authentik.providers.scim.tasks import scim_sync +from authentik.providers.scim.tasks import scim_task_wrapper from authentik.tenants.management import TenantCommand LOGGER = get_logger() @@ -21,4 +21,4 @@ class Command(TenantCommand): if not provider: LOGGER.warning("Provider does not exist", name=provider_name) continue - scim_sync.delay(provider.pk).get() + scim_task_wrapper(provider.pk).get() diff --git a/authentik/providers/scim/signals.py b/authentik/providers/scim/signals.py index a40d36f3b5..f73a327bcc 100644 --- a/authentik/providers/scim/signals.py +++ b/authentik/providers/scim/signals.py @@ -9,7 +9,7 @@ from structlog.stdlib import get_logger from authentik.core.models import Group, User from authentik.lib.utils.reflection import class_to_path from authentik.providers.scim.models import SCIMProvider -from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync +from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper LOGGER = get_logger() @@ -17,7 +17,7 @@ LOGGER = get_logger() @receiver(post_save, sender=SCIMProvider) def post_save_provider(sender: type[Model], instance, created: bool, **_): """Trigger sync when SCIM provider is saved""" - scim_sync.delay(instance.pk) + scim_task_wrapper(instance.pk) @receiver(post_save, sender=User) diff --git a/authentik/providers/scim/tasks.py b/authentik/providers/scim/tasks.py index 15d0caea5d..ed79bbba6b 100644 --- a/authentik/providers/scim/tasks.py +++ b/authentik/providers/scim/tasks.py @@ -38,7 +38,23 @@ def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient: def scim_sync_all(): """Run sync for all providers""" for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False): - scim_sync.delay(provider.pk) + scim_task_wrapper(provider.pk) + + +def scim_task_wrapper(provider_pk: int): + """Wrap scim_sync to set the correct timeouts""" + provider: SCIMProvider = SCIMProvider.objects.filter( + pk=provider_pk, backchannel_application__isnull=False + ).first() + if not provider: + return + users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE) + groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE) + soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT + time_limit = soft_time_limit * 1.5 + return scim_sync.apply_async( + (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) + ) @CELERY_APP.task(bind=True, base=SystemTask) @@ -60,7 +76,7 @@ def scim_sync(self: SystemTask, provider_pk: int) -> None: users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE) groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE) self.soft_time_limit = self.time_limit = ( - users_paginator.count + groups_paginator.count + users_paginator.num_pages + groups_paginator.num_pages ) * PAGE_TIMEOUT with allow_join_result(): try: diff --git a/authentik/providers/scim/tests/test_membership.py b/authentik/providers/scim/tests/test_membership.py index 54d69b4561..342001075a 100644 --- a/authentik/providers/scim/tests/test_membership.py +++ b/authentik/providers/scim/tests/test_membership.py @@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User from authentik.lib.generators import generate_id from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync +from authentik.providers.scim.tasks import scim_task_wrapper from authentik.tenants.models import Tenant @@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase): ) self.configure() - scim_sync.delay(self.provider.pk).get() + scim_task_wrapper(self.provider.pk).get() self.assertEqual(mocker.call_count, 6) self.assertEqual(mocker.request_history[0].method, "GET") @@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase): ) self.configure() - scim_sync.delay(self.provider.pk).get() + scim_task_wrapper(self.provider.pk).get() self.assertEqual(mocker.call_count, 6) self.assertEqual(mocker.request_history[0].method, "GET") diff --git a/authentik/providers/scim/tests/test_user.py b/authentik/providers/scim/tests/test_user.py index 2f22e82b1d..bc1b3817f0 100644 --- a/authentik/providers/scim/tests/test_user.py +++ b/authentik/providers/scim/tests/test_user.py @@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application, Group, User from authentik.lib.generators import generate_id from authentik.providers.scim.models import SCIMMapping, SCIMProvider -from authentik.providers.scim.tasks import scim_sync +from authentik.providers.scim.tasks import scim_task_wrapper from authentik.tenants.models import Tenant @@ -236,7 +236,7 @@ class SCIMUserTests(TestCase): email=f"{uid}@goauthentik.io", ) - scim_sync.delay(self.provider.pk).get() + scim_task_wrapper(self.provider.pk).get() self.assertEqual(mock.call_count, 5) self.assertEqual(mock.request_history[0].method, "GET")