providers/scim: fix time_limit not set correctly (#9546)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L
2024-05-03 11:48:34 +02:00
committed by GitHub
parent f1afc4d263
commit 79df24f4eb
5 changed files with 27 additions and 11 deletions

View File

@ -3,7 +3,7 @@
from structlog.stdlib import get_logger
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync
from authentik.providers.scim.tasks import scim_task_wrapper
from authentik.tenants.management import TenantCommand
LOGGER = get_logger()
@ -21,4 +21,4 @@ class Command(TenantCommand):
if not provider:
LOGGER.warning("Provider does not exist", name=provider_name)
continue
scim_sync.delay(provider.pk).get()
scim_task_wrapper(provider.pk).get()

View File

@ -9,7 +9,7 @@ from structlog.stdlib import get_logger
from authentik.core.models import Group, User
from authentik.lib.utils.reflection import class_to_path
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper
LOGGER = get_logger()
@ -17,7 +17,7 @@ LOGGER = get_logger()
@receiver(post_save, sender=SCIMProvider)
def post_save_provider(sender: type[Model], instance, created: bool, **_):
"""Trigger sync when SCIM provider is saved"""
scim_sync.delay(instance.pk)
scim_task_wrapper(instance.pk)
@receiver(post_save, sender=User)

View File

@ -38,7 +38,23 @@ def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
def scim_sync_all():
"""Run sync for all providers"""
for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
scim_sync.delay(provider.pk)
scim_task_wrapper(provider.pk)
def scim_task_wrapper(provider_pk: int):
"""Wrap scim_sync to set the correct timeouts"""
provider: SCIMProvider = SCIMProvider.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
).first()
if not provider:
return
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return scim_sync.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
@CELERY_APP.task(bind=True, base=SystemTask)
@ -60,7 +76,7 @@ def scim_sync(self: SystemTask, provider_pk: int) -> None:
users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
self.soft_time_limit = self.time_limit = (
users_paginator.count + groups_paginator.count
users_paginator.num_pages + groups_paginator.num_pages
) * PAGE_TIMEOUT
with allow_join_result():
try:

View File

@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync
from authentik.providers.scim.tasks import scim_task_wrapper
from authentik.tenants.models import Tenant
@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase):
)
self.configure()
scim_sync.delay(self.provider.pk).get()
scim_task_wrapper(self.provider.pk).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")
@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase):
)
self.configure()
scim_sync.delay(self.provider.pk).get()
scim_task_wrapper(self.provider.pk).get()
self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET")

View File

@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync
from authentik.providers.scim.tasks import scim_task_wrapper
from authentik.tenants.models import Tenant
@ -236,7 +236,7 @@ class SCIMUserTests(TestCase):
email=f"{uid}@goauthentik.io",
)
scim_sync.delay(self.provider.pk).get()
scim_task_wrapper(self.provider.pk).get()
self.assertEqual(mock.call_count, 5)
self.assertEqual(mock.request_history[0].method, "GET")