leftovers

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-04-01 15:28:08 +02:00
parent d413e2875c
commit bc6085adc7
4 changed files with 29 additions and 37 deletions

View File

@ -114,14 +114,14 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest
@receiver(post_save, sender=Event) @receiver(post_save, sender=Event)
def event_post_save_notification(sender, instance: Event, **_): def event_post_save_notification(sender, instance: Event, **_):
"""Start task to check if any policies trigger an notification on this event""" """Start task to check if any policies trigger an notification on this event"""
event_notification_handler.delay(instance.event_uuid.hex) event_notification_handler.send(instance.event_uuid.hex)
@receiver(pre_delete, sender=User) @receiver(pre_delete, sender=User)
def event_user_pre_delete_cleanup(sender, instance: User, **_): def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events""" """If gdpr_compliance is enabled, remove all the user's events"""
if get_current_tenant().gdpr_compliance: if get_current_tenant().gdpr_compliance:
gdpr_cleanup.delay(instance.pk) gdpr_cleanup.send(instance.pk)
@receiver(monitoring_set) @receiver(monitoring_set)

View File

@ -1,6 +1,7 @@
"""Event notification tasks""" """Event notification tasks"""
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from dramatiq.actor import actor
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -12,24 +13,23 @@ from authentik.events.models import (
NotificationRule, NotificationRule,
NotificationTransport, NotificationTransport,
NotificationTransportError, NotificationTransportError,
TaskStatus,
) )
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.root.celery import CELERY_APP from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task, TaskStatus
LOGGER = get_logger() LOGGER = get_logger()
@CELERY_APP.task() @actor
def event_notification_handler(event_uuid: str): def event_notification_handler(event_uuid: str):
"""Start task for each trigger definition""" """Start task for each trigger definition"""
for trigger in NotificationRule.objects.all(): for trigger in NotificationRule.objects.all():
event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") event_trigger_handler.send(event_uuid, trigger.name)
@CELERY_APP.task() @actor
def event_trigger_handler(event_uuid: str, trigger_name: str): def event_trigger_handler(event_uuid: str, trigger_name: str):
"""Check if policies attached to NotificationRule match event""" """Check if policies attached to NotificationRule match event"""
event: Event = Event.objects.filter(event_uuid=event_uuid).first() event: Event = Event.objects.filter(event_uuid=event_uuid).first()
@ -77,30 +77,22 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
for transport in trigger.transports.all(): for transport in trigger.transports.all():
for user in trigger.group.users.all(): for user in trigger.group.users.all():
LOGGER.debug("created notification") LOGGER.debug("created notification")
notification_transport.apply_async( notification_transport.send(
args=[ transport.pk,
transport.pk, event.pk,
str(event.pk), user.pk,
user.pk, trigger.pk,
str(trigger.pk),
],
queue="authentik_events",
) )
if transport.send_once: if transport.send_once:
break break
@CELERY_APP.task( @actor
bind=True, def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str):
autoretry_for=(NotificationTransportError,),
retry_backoff=True,
base=SystemTask,
)
def notification_transport(
self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
):
"""Send notification over specified transport""" """Send notification over specified transport"""
self.save_on_success = False self: Task = CurrentTask.get_task()
# TODO: fixme
# self.save_on_success = False
try: try:
event = Event.objects.filter(pk=event_pk).first() event = Event.objects.filter(pk=event_pk).first()
if not event: if not event:
@ -124,7 +116,7 @@ def notification_transport(
raise exc raise exc
@CELERY_APP.task() @actor
def gdpr_cleanup(user_pk: int): def gdpr_cleanup(user_pk: int):
"""cleanup events from gdpr_compliance""" """cleanup events from gdpr_compliance"""
events = Event.objects.filter(user__pk=user_pk) events = Event.objects.filter(user__pk=user_pk)
@ -132,10 +124,10 @@ def gdpr_cleanup(user_pk: int):
events.delete() events.delete()
@CELERY_APP.task(bind=True, base=SystemTask) @actor
@prefill_task def notification_cleanup():
def notification_cleanup(self: SystemTask):
"""Cleanup seen notifications and notifications whose event expired.""" """Cleanup seen notifications and notifications whose event expired."""
self: Task = CurrentTask.get_task()
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
amount = notifications.count() amount = notifications.count()
notifications.delete() notifications.delete()

View File

@ -1,5 +1,5 @@
from celery import Task
from django.utils.text import slugify from django.utils.text import slugify
from dramatiq.actor import Actor
from drf_spectacular.utils import OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
@ -45,8 +45,8 @@ class SyncObjectResultSerializer(PassiveSerializer):
class OutgoingSyncProviderStatusMixin: class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers""" """Common API Endpoints for Outgoing sync providers"""
sync_single_task: type[Task] = None sync_single_task: type[Actor] = None
sync_objects_task: type[Task] = None sync_objects_task: type[Actor] = None
@extend_schema( @extend_schema(
responses={ responses={
@ -94,7 +94,7 @@ class OutgoingSyncProviderStatusMixin:
provider: OutgoingSyncProvider = self.get_object() provider: OutgoingSyncProvider = self.get_object()
params = SyncObjectSerializer(data=request.data) params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True) params.is_valid(raise_exception=True)
res: list[LogEvent] = self.sync_objects_task.delay( res: list[LogEvent] = self.sync_objects_task.send(
params.validated_data["sync_object_model"], params.validated_data["sync_object_model"],
page=1, page=1,
provider_pk=provider.pk, provider_pk=provider.pk,

View File

@ -128,7 +128,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
def test_device_challenge_webauthn_restricted(self): def test_device_challenge_webauthn_restricted(self):
"""Test webauthn (getting device challenges with a webauthn """Test webauthn (getting device challenges with a webauthn
device that is not allowed due to aaguid restrictions)""" device that is not allowed due to aaguid restrictions)"""
webauthn_mds_import.delay(force=True).get() webauthn_mds_import.send(force=True).get_result()
request = get_request("/") request = get_request("/")
request.user = self.user request.user = self.user
@ -245,7 +245,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
def test_validate_challenge_unrestricted(self): def test_validate_challenge_unrestricted(self):
"""Test webauthn authentication (unrestricted webauthn device)""" """Test webauthn authentication (unrestricted webauthn device)"""
webauthn_mds_import.delay(force=True).get() webauthn_mds_import.send(force=True).get_result()
device = WebAuthnDevice.objects.create( device = WebAuthnDevice.objects.create(
user=self.user, user=self.user,
public_key=( public_key=(
@ -319,7 +319,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
def test_validate_challenge_restricted(self): def test_validate_challenge_restricted(self):
"""Test webauthn authentication (restricted device type, failure)""" """Test webauthn authentication (restricted device type, failure)"""
webauthn_mds_import.delay(force=True).get() webauthn_mds_import.send(force=True).get_result()
device = WebAuthnDevice.objects.create( device = WebAuthnDevice.objects.create(
user=self.user, user=self.user,
public_key=( public_key=(