From 107b96e65c4fd2f41778fb1ea19f4512e40a99e2 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Wed, 4 Jun 2025 17:44:23 +0200 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/enterprise/providers/ssf/tasks.py | 68 ++++++--------------- authentik/tasks/models.py | 17 ++++-- 2 files changed, 31 insertions(+), 54 deletions(-) diff --git a/authentik/enterprise/providers/ssf/tasks.py b/authentik/enterprise/providers/ssf/tasks.py index f842c90f6c..70800f8bc8 100644 --- a/authentik/enterprise/providers/ssf/tasks.py +++ b/authentik/enterprise/providers/ssf/tasks.py @@ -1,6 +1,8 @@ +from typing import Any +from uuid import UUID + from django.http import HttpRequest from django.utils.timezone import now -from django.utils.translation import gettext_lazy as _ from dramatiq.actor import actor from dramatiq.composition import group from requests.exceptions import RequestException @@ -19,7 +21,6 @@ from authentik.lib.utils.http import get_http_session from authentik.lib.utils.time import timedelta_from_string from authentik.policies.engine import PolicyEngine from authentik.tasks.middleware import CurrentTask -from authentik.tasks.models import Task, TaskStatus session = get_http_session() LOGGER = get_logger() @@ -33,7 +34,6 @@ def send_ssf_event( **extra_data, ): """Wrapper to send an SSF event to multiple streams""" - payload = [] if not stream_filter: stream_filter = {} stream_filter["events_requested__contains"] = [event_type] @@ -41,16 +41,12 @@ def send_ssf_event( extra_data.setdefault("txn", request.request_id) for stream in Stream.objects.filter(**stream_filter): event_data = stream.prepare_event_payload(event_type, data, **extra_data) - payload.append((str(stream.uuid), event_data)) - return _send_ssf_event.send(payload) + _send_ssf_event.send_with_options(args=(stream.uuid, event_data), rel_obj=stream.provider) -def _check_app_access(stream_uuid: str, event_data: dict) -> bool: +def _check_app_access(stream: Stream, event_data: dict) -> bool: """Check if event is related to user and if so, check if the user has access to the application""" - stream = Stream.objects.filter(pk=stream_uuid).first() - if not stream: - return False # `event_data` is a dict version of a StreamEvent sub_id = event_data.get("payload", {}).get("sub_id", {}) email = sub_id.get("user", {}).get("email", None) @@ -66,43 +62,21 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool: @actor -def _send_ssf_event(event_data: list[tuple[str, dict]]): - tasks = [] - for stream, data in event_data: - if not _check_app_access(stream, data): - continue - event = StreamEvent.objects.create(**data) - tasks.extend(send_single_ssf_event(stream, str(event.uuid))) - main_task = group(tasks) - main_task.run() +def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): + self = CurrentTask.get_task() - -def send_single_ssf_event(stream_id: str, evt_id: str): - stream = Stream.objects.filter(pk=stream_id).first() + stream = Stream.objects.filter(pk=stream_uuid).first() if not stream: - return [] - event = StreamEvent.objects.filter(pk=evt_id).first() - if not event: - return [] - if event.status == SSFEventStatus.SENT: - return [] - if stream.delivery_method == DeliveryMethods.RISC_PUSH: - return [ssf_push_event.message(str(event.pk))] - return [] - - -@actor -def ssf_push_event(event_id: str): - self: Task = CurrentTask.get_task() - # TODO: fix me - # self.save_on_success = False - event = StreamEvent.objects.filter(pk=event_id).first() - if not event: return - self.set_uid(event_id) - if event.status == SSFEventStatus.SENT: - self.set_status(TaskStatus.SUCCESSFUL) + if not _check_app_access(stream, event_data): return + event = StreamEvent.objects.create(**event_data) + self.set_uid(event.pk) + if event.status == SSFEventStatus.SENT: + return + if stream.delivery_method != DeliveryMethods.RISC_PUSH: + return + try: response = session.post( event.stream.endpoint_url, @@ -112,25 +86,21 @@ def ssf_push_event(event_id: str): response.raise_for_status() event.status = SSFEventStatus.SENT event.save() - self.set_status(TaskStatus.SUCCESSFUL) return except RequestException as exc: LOGGER.warning("Failed to send SSF event", exc=exc) - self.set_status(TaskStatus.ERROR) attrs = {} if exc.response: attrs["response"] = { "content": exc.response.text, "status": exc.response.status_code, } - self.set_error( + self.error( exc, LogEvent( - _("Failed to send request"), + "Failed to send request", log_level="warning", - # TODO: fix me - # logger=self.__name__, - logger=str(self.uid), + logger=self.uid, attributes=attrs, ), ) diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py index 41f7d55032..8ed6c321c5 100644 --- a/authentik/tasks/models.py +++ b/authentik/tasks/models.py @@ -1,5 +1,5 @@ from enum import StrEnum, auto -from uuid import uuid4 +from uuid import UUID, uuid4 import pgtrigger from django.contrib.contenttypes.fields import ContentType, GenericForeignKey @@ -63,7 +63,7 @@ class Task(SerializerModel): rel_obj_id = models.TextField(null=True) rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id") - uid = models.TextField(blank=True, null=True) + _uid = models.TextField(blank=True, null=True) messages = models.JSONField(default=list) class Meta: @@ -94,14 +94,21 @@ class Task(SerializerModel): def __str__(self): return str(self.message_id) + @property + def uid(self) -> str: + uid = str(self.actor_name) + if self._uid: + uid += f":{self._uid}" + return uid + @property def serializer(self): from authentik.tasks.api import TaskSerializer return TaskSerializer - def set_uid(self, uid: str, save: bool = False): - self.uid = uid + def set_uid(self, uid: str | UUID, save: bool = False): + self._uid = str(uid) if save: self.save() @@ -112,7 +119,7 @@ class Task(SerializerModel): if isinstance(message, Exception): message = exception_to_string(message) if not isinstance(message, LogEvent): - message = LogEvent(message, logger=self.actor_name, log_level=status.value) + message = LogEvent(message, logger=self.uid, log_level=status.value) self.messages.append(sanitize_item(message)) if save: self.save()