Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-04 17:44:23 +02:00
parent 5d7ba51872
commit 107b96e65c
2 changed files with 31 additions and 54 deletions

View File

@ -1,6 +1,8 @@
from typing import Any
from uuid import UUID
from django.http import HttpRequest
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from dramatiq.composition import group
from requests.exceptions import RequestException
@ -19,7 +21,6 @@ from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task, TaskStatus
session = get_http_session()
LOGGER = get_logger()
@ -33,7 +34,6 @@ def send_ssf_event(
**extra_data,
):
"""Wrapper to send an SSF event to multiple streams"""
payload = []
if not stream_filter:
stream_filter = {}
stream_filter["events_requested__contains"] = [event_type]
@ -41,16 +41,12 @@ def send_ssf_event(
extra_data.setdefault("txn", request.request_id)
for stream in Stream.objects.filter(**stream_filter):
event_data = stream.prepare_event_payload(event_type, data, **extra_data)
payload.append((str(stream.uuid), event_data))
return _send_ssf_event.send(payload)
_send_ssf_event.send_with_options(args=(stream.uuid, event_data), rel_obj=stream.provider)
def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
def _check_app_access(stream: Stream, event_data: dict) -> bool:
"""Check if event is related to user and if so, check
if the user has access to the application"""
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:
return False
# `event_data` is a dict version of a StreamEvent
sub_id = event_data.get("payload", {}).get("sub_id", {})
email = sub_id.get("user", {}).get("email", None)
@ -66,43 +62,21 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
@actor
def _send_ssf_event(event_data: list[tuple[str, dict]]):
tasks = []
for stream, data in event_data:
if not _check_app_access(stream, data):
continue
event = StreamEvent.objects.create(**data)
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
main_task = group(tasks)
main_task.run()
def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
self = CurrentTask.get_task()
def send_single_ssf_event(stream_id: str, evt_id: str):
stream = Stream.objects.filter(pk=stream_id).first()
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:
return []
event = StreamEvent.objects.filter(pk=evt_id).first()
if not event:
return []
if event.status == SSFEventStatus.SENT:
return []
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
return [ssf_push_event.message(str(event.pk))]
return []
@actor
def ssf_push_event(event_id: str):
self: Task = CurrentTask.get_task()
# TODO: fix me
# self.save_on_success = False
event = StreamEvent.objects.filter(pk=event_id).first()
if not event:
return
self.set_uid(event_id)
if event.status == SSFEventStatus.SENT:
self.set_status(TaskStatus.SUCCESSFUL)
if not _check_app_access(stream, event_data):
return
event = StreamEvent.objects.create(**event_data)
self.set_uid(event.pk)
if event.status == SSFEventStatus.SENT:
return
if stream.delivery_method != DeliveryMethods.RISC_PUSH:
return
try:
response = session.post(
event.stream.endpoint_url,
@ -112,25 +86,21 @@ def ssf_push_event(event_id: str):
response.raise_for_status()
event.status = SSFEventStatus.SENT
event.save()
self.set_status(TaskStatus.SUCCESSFUL)
return
except RequestException as exc:
LOGGER.warning("Failed to send SSF event", exc=exc)
self.set_status(TaskStatus.ERROR)
attrs = {}
if exc.response:
attrs["response"] = {
"content": exc.response.text,
"status": exc.response.status_code,
}
self.set_error(
self.error(
exc,
LogEvent(
_("Failed to send request"),
"Failed to send request",
log_level="warning",
# TODO: fix me
# logger=self.__name__,
logger=str(self.uid),
logger=self.uid,
attributes=attrs,
),
)

View File

@ -1,5 +1,5 @@
from enum import StrEnum, auto
from uuid import uuid4
from uuid import UUID, uuid4
import pgtrigger
from django.contrib.contenttypes.fields import ContentType, GenericForeignKey
@ -63,7 +63,7 @@ class Task(SerializerModel):
rel_obj_id = models.TextField(null=True)
rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id")
uid = models.TextField(blank=True, null=True)
_uid = models.TextField(blank=True, null=True)
messages = models.JSONField(default=list)
class Meta:
@ -94,14 +94,21 @@ class Task(SerializerModel):
def __str__(self):
return str(self.message_id)
@property
def uid(self) -> str:
uid = str(self.actor_name)
if self._uid:
uid += f":{self._uid}"
return uid
@property
def serializer(self):
from authentik.tasks.api import TaskSerializer
return TaskSerializer
def set_uid(self, uid: str, save: bool = False):
self.uid = uid
def set_uid(self, uid: str | UUID, save: bool = False):
self._uid = str(uid)
if save:
self.save()
@ -112,7 +119,7 @@ class Task(SerializerModel):
if isinstance(message, Exception):
message = exception_to_string(message)
if not isinstance(message, LogEvent):
message = LogEvent(message, logger=self.actor_name, log_level=status.value)
message = LogEvent(message, logger=self.uid, log_level=status.value)
self.messages.append(sanitize_item(message))
if save:
self.save()