diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 91ff629d3b..f8976ecd74 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -67,6 +67,7 @@ SHARED_APPS = [ "pgactivity", "pglock", "channels", + "authentik.tasks", ] TENANT_APPS = [ "django.contrib.auth", @@ -122,7 +123,6 @@ TENANT_APPS = [ "authentik.stages.user_login", "authentik.stages.user_logout", "authentik.stages.user_write", - "authentik.tasks", "authentik.brands", "authentik.blueprints", "guardian", diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py index ea69782e6f..246dbae274 100644 --- a/authentik/tasks/apps.py +++ b/authentik/tasks/apps.py @@ -1,4 +1,8 @@ +from datetime import timedelta +import dramatiq +from dramatiq.middleware import AgeLimit, Callbacks, Prometheus, Retries, TimeLimit from authentik.blueprints.apps import ManagedAppConfig +from authentik.tasks.encoder import JSONPickleEncoder class AuthentikTasksConfig(ManagedAppConfig): @@ -6,3 +10,16 @@ class AuthentikTasksConfig(ManagedAppConfig): label = "authentik_tasks" verbose_name = "authentik Tasks" default = True + + def ready(self) -> None: + from authentik.tasks.broker import PostgresBroker + + dramatiq.set_encoder(JSONPickleEncoder()) + broker = PostgresBroker() + broker.add_middleware(Prometheus()) + broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000)) + broker.add_middleware(TimeLimit()) + broker.add_middleware(Callbacks()) + broker.add_middleware(Retries(max_retries=3)) + dramatiq.set_broker(broker) + return super().ready() diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index c807347bd8..d2c2c2b61e 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,3 +1,4 @@ +from enum import Enum, StrEnum, auto import functools import logging import time @@ -23,12 +24,18 @@ from structlog.stdlib import get_logger from authentik.tasks.models import Queue as MQueue from authentik.tasks.results import PostgresBackend +from authentik.tenants.utils import get_current_tenant LOGGER = get_logger() -def channel_name(connection: DatabaseWrapper, queue_name: str, identifier: str) -> str: - return f"{connection.schema_name}.dramatiq.{queue_name}.{identifier}" +class ChannelIdentifier(StrEnum): + ENQUEUE = auto() + LOCK = auto() + + +def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: + return f"authentik.tasks.{queue_name}.{identifier.value}" def raise_connection_error(func): @@ -117,6 +124,7 @@ class PostgresBroker(Broker): # TODO: update_or_create self.query_set.create( message_id=message.message_id, + tenant=get_current_tenant(), queue_name=message.queue_name, state=MQueue.State.QUEUED, message=message.encode(), @@ -196,9 +204,7 @@ class _PostgresConsumer(Consumer): # Should be set to True by Django by default self._listen_connection.set_autocommit(True) with self._listen_connection.cursor() as cursor: - cursor.execute( - "LISTEN %s", channel_name(self._listen_connection, self.queue_name, "enqueue") - ) + cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)) @raise_connection_error def ack(self, message: Message): @@ -240,7 +246,7 @@ class _PostgresConsumer(Consumer): notifies = self.query_set.filter( state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name ) - channel = channel_name(self.connection, self.queue_name, "enqueue") + channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE) return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies] def _poll_for_notify(self): @@ -251,7 +257,7 @@ class _PostgresConsumer(Consumer): def _get_message_lock_id(self, message: Message) -> int: return _cast_lock_id( - f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501 + f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}" ) def _consume_one(self, message: Message) -> bool: diff --git a/authentik/tasks/encoder.py b/authentik/tasks/encoder.py new file mode 100644 index 0000000000..833b625d4b --- /dev/null +++ b/authentik/tasks/encoder.py @@ -0,0 +1,32 @@ +import jsonpickle +import dramatiq.encoder +from typing import Any +from dramatiq.encoder import MessageData +import orjson + + +class OrjsonBackend(jsonpickle.JSONBackend): + def encode(self, obj: Any, indent=None, separators=None) -> str: + return orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") + + def decode(self, string: str) -> Any: + return orjson.loads(string) + + +class JSONPickleEncoder(dramatiq.encoder.Encoder): + def encode(self, data: MessageData) -> bytes: + return jsonpickle.encode( + data, + backend=OrjsonBackend(), + keys=True, + warn=True, + use_base85=True, + ).encode("utf-8") + + def decode(self, data: bytes) -> MessageData: + return jsonpickle.decode( + data.decode("utf-8"), + backend=OrjsonBackend(), + keys=True, + on_missing="warn", + ) diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py index 59c93a2f4b..37fc5201b4 100644 --- a/authentik/tasks/models.py +++ b/authentik/tasks/models.py @@ -3,6 +3,8 @@ from uuid import uuid4 from django.db import models from django.utils import timezone +from authentik.tenants.models import Tenant + class Queue(models.Model): class State(models.TextChoices): @@ -11,6 +13,7 @@ class Queue(models.Model): REJECTED = "rejected" DONE = "done" + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False) queue_name = models.TextField(default="default") state = models.CharField(default=State.QUEUED, choices=State.choices)