@ -1,6 +1,8 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
from dramatiq.composition import group
|
||||
from requests.exceptions import RequestException
|
||||
@ -19,7 +21,6 @@ from authentik.lib.utils.http import get_http_session
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.tasks.middleware import CurrentTask
|
||||
from authentik.tasks.models import Task, TaskStatus
|
||||
|
||||
session = get_http_session()
|
||||
LOGGER = get_logger()
|
||||
@ -33,7 +34,6 @@ def send_ssf_event(
|
||||
**extra_data,
|
||||
):
|
||||
"""Wrapper to send an SSF event to multiple streams"""
|
||||
payload = []
|
||||
if not stream_filter:
|
||||
stream_filter = {}
|
||||
stream_filter["events_requested__contains"] = [event_type]
|
||||
@ -41,16 +41,12 @@ def send_ssf_event(
|
||||
extra_data.setdefault("txn", request.request_id)
|
||||
for stream in Stream.objects.filter(**stream_filter):
|
||||
event_data = stream.prepare_event_payload(event_type, data, **extra_data)
|
||||
payload.append((str(stream.uuid), event_data))
|
||||
return _send_ssf_event.send(payload)
|
||||
_send_ssf_event.send_with_options(args=(stream.uuid, event_data), rel_obj=stream.provider)
|
||||
|
||||
|
||||
def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
|
||||
def _check_app_access(stream: Stream, event_data: dict) -> bool:
|
||||
"""Check if event is related to user and if so, check
|
||||
if the user has access to the application"""
|
||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||
if not stream:
|
||||
return False
|
||||
# `event_data` is a dict version of a StreamEvent
|
||||
sub_id = event_data.get("payload", {}).get("sub_id", {})
|
||||
email = sub_id.get("user", {}).get("email", None)
|
||||
@ -66,43 +62,21 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
|
||||
|
||||
|
||||
@actor
|
||||
def _send_ssf_event(event_data: list[tuple[str, dict]]):
|
||||
tasks = []
|
||||
for stream, data in event_data:
|
||||
if not _check_app_access(stream, data):
|
||||
continue
|
||||
event = StreamEvent.objects.create(**data)
|
||||
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
|
||||
main_task = group(tasks)
|
||||
main_task.run()
|
||||
def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
|
||||
self = CurrentTask.get_task()
|
||||
|
||||
|
||||
def send_single_ssf_event(stream_id: str, evt_id: str):
|
||||
stream = Stream.objects.filter(pk=stream_id).first()
|
||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||
if not stream:
|
||||
return []
|
||||
event = StreamEvent.objects.filter(pk=evt_id).first()
|
||||
if not event:
|
||||
return []
|
||||
if event.status == SSFEventStatus.SENT:
|
||||
return []
|
||||
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
|
||||
return [ssf_push_event.message(str(event.pk))]
|
||||
return []
|
||||
|
||||
|
||||
@actor
|
||||
def ssf_push_event(event_id: str):
|
||||
self: Task = CurrentTask.get_task()
|
||||
# TODO: fix me
|
||||
# self.save_on_success = False
|
||||
event = StreamEvent.objects.filter(pk=event_id).first()
|
||||
if not event:
|
||||
return
|
||||
self.set_uid(event_id)
|
||||
if event.status == SSFEventStatus.SENT:
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
if not _check_app_access(stream, event_data):
|
||||
return
|
||||
event = StreamEvent.objects.create(**event_data)
|
||||
self.set_uid(event.pk)
|
||||
if event.status == SSFEventStatus.SENT:
|
||||
return
|
||||
if stream.delivery_method != DeliveryMethods.RISC_PUSH:
|
||||
return
|
||||
|
||||
try:
|
||||
response = session.post(
|
||||
event.stream.endpoint_url,
|
||||
@ -112,25 +86,21 @@ def ssf_push_event(event_id: str):
|
||||
response.raise_for_status()
|
||||
event.status = SSFEventStatus.SENT
|
||||
event.save()
|
||||
self.set_status(TaskStatus.SUCCESSFUL)
|
||||
return
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Failed to send SSF event", exc=exc)
|
||||
self.set_status(TaskStatus.ERROR)
|
||||
attrs = {}
|
||||
if exc.response:
|
||||
attrs["response"] = {
|
||||
"content": exc.response.text,
|
||||
"status": exc.response.status_code,
|
||||
}
|
||||
self.set_error(
|
||||
self.error(
|
||||
exc,
|
||||
LogEvent(
|
||||
_("Failed to send request"),
|
||||
"Failed to send request",
|
||||
log_level="warning",
|
||||
# TODO: fix me
|
||||
# logger=self.__name__,
|
||||
logger=str(self.uid),
|
||||
logger=self.uid,
|
||||
attributes=attrs,
|
||||
),
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from enum import StrEnum, auto
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pgtrigger
|
||||
from django.contrib.contenttypes.fields import ContentType, GenericForeignKey
|
||||
@ -63,7 +63,7 @@ class Task(SerializerModel):
|
||||
rel_obj_id = models.TextField(null=True)
|
||||
rel_obj = GenericForeignKey("rel_obj_content_type", "rel_obj_id")
|
||||
|
||||
uid = models.TextField(blank=True, null=True)
|
||||
_uid = models.TextField(blank=True, null=True)
|
||||
messages = models.JSONField(default=list)
|
||||
|
||||
class Meta:
|
||||
@ -94,14 +94,21 @@ class Task(SerializerModel):
|
||||
def __str__(self):
|
||||
return str(self.message_id)
|
||||
|
||||
@property
|
||||
def uid(self) -> str:
|
||||
uid = str(self.actor_name)
|
||||
if self._uid:
|
||||
uid += f":{self._uid}"
|
||||
return uid
|
||||
|
||||
@property
|
||||
def serializer(self):
|
||||
from authentik.tasks.api import TaskSerializer
|
||||
|
||||
return TaskSerializer
|
||||
|
||||
def set_uid(self, uid: str, save: bool = False):
|
||||
self.uid = uid
|
||||
def set_uid(self, uid: str | UUID, save: bool = False):
|
||||
self._uid = str(uid)
|
||||
if save:
|
||||
self.save()
|
||||
|
||||
@ -112,7 +119,7 @@ class Task(SerializerModel):
|
||||
if isinstance(message, Exception):
|
||||
message = exception_to_string(message)
|
||||
if not isinstance(message, LogEvent):
|
||||
message = LogEvent(message, logger=self.actor_name, log_level=status.value)
|
||||
message = LogEvent(message, logger=self.uid, log_level=status.value)
|
||||
self.messages.append(sanitize_item(message))
|
||||
if save:
|
||||
self.save()
|
||||
|
||||
Reference in New Issue
Block a user