Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-04 17:44:23 +02:00
parent 5d7ba51872
commit 107b96e65c
2 changed files with 31 additions and 54 deletions

View File

@ -1,6 +1,8 @@
from typing import Any
from uuid import UUID
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor from dramatiq.actor import actor
from dramatiq.composition import group from dramatiq.composition import group
from requests.exceptions import RequestException from requests.exceptions import RequestException
@ -19,7 +21,6 @@ from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.tasks.middleware import CurrentTask from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task, TaskStatus
session = get_http_session() session = get_http_session()
LOGGER = get_logger() LOGGER = get_logger()
@ -33,7 +34,6 @@ def send_ssf_event(
**extra_data, **extra_data,
): ):
"""Wrapper to send an SSF event to multiple streams""" """Wrapper to send an SSF event to multiple streams"""
payload = []
if not stream_filter: if not stream_filter:
stream_filter = {} stream_filter = {}
stream_filter["events_requested__contains"] = [event_type] stream_filter["events_requested__contains"] = [event_type]
@ -41,16 +41,12 @@ def send_ssf_event(
extra_data.setdefault("txn", request.request_id) extra_data.setdefault("txn", request.request_id)
for stream in Stream.objects.filter(**stream_filter): for stream in Stream.objects.filter(**stream_filter):
event_data = stream.prepare_event_payload(event_type, data, **extra_data) event_data = stream.prepare_event_payload(event_type, data, **extra_data)
payload.append((str(stream.uuid), event_data)) _send_ssf_event.send_with_options(args=(stream.uuid, event_data), rel_obj=stream.provider)
return _send_ssf_event.send(payload)
def _check_app_access(stream_uuid: str, event_data: dict) -> bool: def _check_app_access(stream: Stream, event_data: dict) -> bool:
"""Check if event is related to user and if so, check """Check if event is related to user and if so, check
if the user has access to the application""" if the user has access to the application"""
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:
return False
# `event_data` is a dict version of a StreamEvent # `event_data` is a dict version of a StreamEvent
sub_id = event_data.get("payload", {}).get("sub_id", {}) sub_id = event_data.get("payload", {}).get("sub_id", {})
email = sub_id.get("user", {}).get("email", None) email = sub_id.get("user", {}).get("email", None)
@ -66,43 +62,21 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
@actor @actor
def _send_ssf_event(event_data: list[tuple[str, dict]]): def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
tasks = [] self = CurrentTask.get_task()
for stream, data in event_data:
if not _check_app_access(stream, data):
continue
event = StreamEvent.objects.create(**data)
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
main_task = group(tasks)
main_task.run()
stream = Stream.objects.filter(pk=stream_uuid).first()
def send_single_ssf_event(stream_id: str, evt_id: str):
stream = Stream.objects.filter(pk=stream_id).first()
if not stream: if not stream:
return []
event = StreamEvent.objects.filter(pk=evt_id).first()
if not event:
return []
if event.status == SSFEventStatus.SENT:
return []
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
return [ssf_push_event.message(str(event.pk))]
return []
@actor
def ssf_push_event(event_id: str):
self: Task = CurrentTask.get_task()
# TODO: fix me
# self.save_on_success = False
event = StreamEvent.objects.filter(pk=event_id).first()
if not event:
return return
self.set_uid(event_id) if not _check_app_access(stream, event_data):
if event.status == SSFEventStatus.SENT:
self.set_status(TaskStatus.SUCCESSFUL)
return return
event = StreamEvent.objects.create(**event_data)
self.set_uid(event.pk)
if event.status == SSFEventStatus.SENT:
return
if stream.delivery_method != DeliveryMethods.RISC_PUSH:
return
try: try:
response = session.post( response = session.post(
event.stream.endpoint_url, event.stream.endpoint_url,
@ -112,25 +86,21 @@ def ssf_push_event(event_id: str):
response.raise_for_status() response.raise_for_status()
event.status = SSFEventStatus.SENT event.status = SSFEventStatus.SENT
event.save() event.save()
self.set_status(TaskStatus.SUCCESSFUL)
return return
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Failed to send SSF event", exc=exc) LOGGER.warning("Failed to send SSF event", exc=exc)
self.set_status(TaskStatus.ERROR)
attrs = {} attrs = {}
if exc.response: if exc.response:
attrs["response"] = { attrs["response"] = {
"content": exc.response.text, "content": exc.response.text,
"status": exc.response.status_code, "status": exc.response.status_code,
} }
self.set_error( self.error(
exc, exc,
LogEvent( LogEvent(
_("Failed to send request"), "Failed to send request",
log_level="warning", log_level="warning",
# TODO: fix me logger=self.uid,
# logger=self.__name__,
logger=str(self.uid),
attributes=attrs, attributes=attrs,
), ),
) )

View File

@ -1,5 +1,5 @@
from enum import StrEnum, auto from enum import StrEnum, auto
from uuid import uuid4 from uuid import UUID, uuid4
import pgtrigger import pgtrigger
from django.contrib.contenttypes.fields import ContentType, GenericForeignKey from django.contrib.contenttypes.fields import ContentType, GenericForeignKey
@ -63,7 +63,7 @@ class Task(SerializerModel):
rel_obj_id = models.TextField(null=True) rel_obj_id = models.TextField(null=True)
rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id") rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id")
uid = models.TextField(blank=True, null=True) _uid = models.TextField(blank=True, null=True)
messages = models.JSONField(default=list) messages = models.JSONField(default=list)
class Meta: class Meta:
@ -94,14 +94,21 @@ class Task(SerializerModel):
def __str__(self): def __str__(self):
return str(self.message_id) return str(self.message_id)
@property
def uid(self) -> str:
uid = str(self.actor_name)
if self._uid:
uid += f":{self._uid}"
return uid
@property @property
def serializer(self): def serializer(self):
from authentik.tasks.api import TaskSerializer from authentik.tasks.api import TaskSerializer
return TaskSerializer return TaskSerializer
def set_uid(self, uid: str, save: bool = False): def set_uid(self, uid: str | UUID, save: bool = False):
self.uid = uid self._uid = str(uid)
if save: if save:
self.save() self.save()
@ -112,7 +119,7 @@ class Task(SerializerModel):
if isinstance(message, Exception): if isinstance(message, Exception):
message = exception_to_string(message) message = exception_to_string(message)
if not isinstance(message, LogEvent): if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.actor_name, log_level=status.value) message = LogEvent(message, logger=self.uid, log_level=status.value)
self.messages.append(sanitize_item(message)) self.messages.append(sanitize_item(message))
if save: if save:
self.save() self.save()