From bc6085adc7031eeb92db078b9e321860e1500a82 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Tue, 1 Apr 2025 15:28:08 +0200 Subject: [PATCH] leftovers Signed-off-by: Marc 'risson' Schmitt --- authentik/events/signals.py | 4 +- authentik/events/tasks.py | 48 ++++++++----------- authentik/lib/sync/outgoing/api.py | 8 ++-- .../tests/test_webauthn.py | 6 +-- 4 files changed, 29 insertions(+), 37 deletions(-) diff --git a/authentik/events/signals.py b/authentik/events/signals.py index e3408c3b21..a73d44038e 100644 --- a/authentik/events/signals.py +++ b/authentik/events/signals.py @@ -114,14 +114,14 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest @receiver(post_save, sender=Event) def event_post_save_notification(sender, instance: Event, **_): """Start task to check if any policies trigger an notification on this event""" - event_notification_handler.delay(instance.event_uuid.hex) + event_notification_handler.send(instance.event_uuid.hex) @receiver(pre_delete, sender=User) def event_user_pre_delete_cleanup(sender, instance: User, **_): """If gdpr_compliance is enabled, remove all the user's events""" if get_current_tenant().gdpr_compliance: - gdpr_cleanup.delay(instance.pk) + gdpr_cleanup.send(instance.pk) @receiver(monitoring_set) diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index d923b068b5..9fbe26531c 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -1,6 +1,7 @@ """Event notification tasks""" from django.db.models.query_utils import Q +from dramatiq.actor import actor from guardian.shortcuts import get_anonymous_user from structlog.stdlib import get_logger @@ -12,24 +13,23 @@ from authentik.events.models import ( NotificationRule, NotificationTransport, NotificationTransportError, - TaskStatus, ) -from authentik.events.system_tasks import SystemTask, prefill_task from authentik.policies.engine import PolicyEngine from authentik.policies.models import PolicyBinding, PolicyEngineMode -from authentik.root.celery import CELERY_APP +from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task, TaskStatus LOGGER = get_logger() -@CELERY_APP.task() +@actor def event_notification_handler(event_uuid: str): """Start task for each trigger definition""" for trigger in NotificationRule.objects.all(): - event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") + event_trigger_handler.send(event_uuid, trigger.name) -@CELERY_APP.task() +@actor def event_trigger_handler(event_uuid: str, trigger_name: str): """Check if policies attached to NotificationRule match event""" event: Event = Event.objects.filter(event_uuid=event_uuid).first() @@ -77,30 +77,22 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): for transport in trigger.transports.all(): for user in trigger.group.users.all(): LOGGER.debug("created notification") - notification_transport.apply_async( - args=[ - transport.pk, - str(event.pk), - user.pk, - str(trigger.pk), - ], - queue="authentik_events", + notification_transport.send( + transport.pk, + event.pk, + user.pk, + trigger.pk, ) if transport.send_once: break -@CELERY_APP.task( - bind=True, - autoretry_for=(NotificationTransportError,), - retry_backoff=True, - base=SystemTask, -) -def notification_transport( - self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str -): +@actor +def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): """Send notification over specified transport""" - self.save_on_success = False + self: Task = CurrentTask.get_task() + # TODO: fixme + # self.save_on_success = False try: event = Event.objects.filter(pk=event_pk).first() if not event: @@ -124,7 +116,7 @@ def notification_transport( raise exc -@CELERY_APP.task() +@actor def gdpr_cleanup(user_pk: int): """cleanup events from gdpr_compliance""" events = Event.objects.filter(user__pk=user_pk) @@ -132,10 +124,10 @@ def gdpr_cleanup(user_pk: int): events.delete() -@CELERY_APP.task(bind=True, base=SystemTask) -@prefill_task -def notification_cleanup(self: SystemTask): +@actor +def notification_cleanup(): """Cleanup seen notifications and notifications whose event expired.""" + self: Task = CurrentTask.get_task() notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) amount = notifications.count() notifications.delete() diff --git a/authentik/lib/sync/outgoing/api.py b/authentik/lib/sync/outgoing/api.py index ee6a3c8e03..a73239f3b1 100644 --- a/authentik/lib/sync/outgoing/api.py +++ b/authentik/lib/sync/outgoing/api.py @@ -1,5 +1,5 @@ -from celery import Task from django.utils.text import slugify +from dramatiq.actor import Actor from drf_spectacular.utils import OpenApiResponse, extend_schema from guardian.shortcuts import get_objects_for_user from rest_framework.decorators import action @@ -45,8 +45,8 @@ class SyncObjectResultSerializer(PassiveSerializer): class OutgoingSyncProviderStatusMixin: """Common API Endpoints for Outgoing sync providers""" - sync_single_task: type[Task] = None - sync_objects_task: type[Task] = None + sync_single_task: type[Actor] = None + sync_objects_task: type[Actor] = None @extend_schema( responses={ @@ -94,7 +94,7 @@ class OutgoingSyncProviderStatusMixin: provider: OutgoingSyncProvider = self.get_object() params = SyncObjectSerializer(data=request.data) params.is_valid(raise_exception=True) - res: list[LogEvent] = self.sync_objects_task.delay( + res: list[LogEvent] = self.sync_objects_task.send( params.validated_data["sync_object_model"], page=1, provider_pk=provider.pk, diff --git a/authentik/stages/authenticator_validate/tests/test_webauthn.py b/authentik/stages/authenticator_validate/tests/test_webauthn.py index 05fe216081..a2c1497d52 100644 --- a/authentik/stages/authenticator_validate/tests/test_webauthn.py +++ b/authentik/stages/authenticator_validate/tests/test_webauthn.py @@ -128,7 +128,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_device_challenge_webauthn_restricted(self): """Test webauthn (getting device challenges with a webauthn device that is not allowed due to aaguid restrictions)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() request = get_request("/") request.user = self.user @@ -245,7 +245,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_validate_challenge_unrestricted(self): """Test webauthn authentication (unrestricted webauthn device)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() device = WebAuthnDevice.objects.create( user=self.user, public_key=( @@ -319,7 +319,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): def test_validate_challenge_restricted(self): """Test webauthn authentication (restricted device type, failure)""" - webauthn_mds_import.delay(force=True).get() + webauthn_mds_import.send(force=True).get_result() device = WebAuthnDevice.objects.create( user=self.user, public_key=(