From 16fd9cab674f26d5fd0746794c89846f3ae51941 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Wed, 18 Jun 2025 17:01:05 +0200 Subject: [PATCH] move all broker stuff to package, schedule is still todo Signed-off-by: Marc 'risson' Schmitt --- authentik/root/settings.py | 19 +++-- authentik/tasks/apps.py | 69 +++--------------- authentik/tasks/broker.py | 62 +--------------- authentik/tasks/middleware.py | 15 ++++ .../django_dramatiq_postgres/apps.py | 10 +-- .../django_dramatiq_postgres/broker.py | 51 +++++--------- .../django_dramatiq_postgres/conf.py | 70 +++++++++++++------ .../django_dramatiq_postgres/middleware.py | 2 +- .../django_dramatiq_postgres/models.py | 2 +- 9 files changed, 113 insertions(+), 187 deletions(-) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 5e3fabf79f..a248dacae2 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -354,26 +354,33 @@ TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" # Dramatiq DRAMATIQ = { + "broker_class": "authentik.tasks.broker.Broker", + "channel_prefix": "authentik.tasks", + "task_class": "authentik.tasks.models.Task", "middlewares": ( # TODO: fixme # ("dramatiq.middleware.prometheus.Prometheus", {}), + ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), ("dramatiq.middleware.age_limit.AgeLimit", {}), ( + # 5 minutes task timeout by default for all tasks, in ms "dramatiq.middleware.time_limit.TimeLimit", - { - # 5 minutes task timeout by default for all tasks - "time_limit": 600 * 1000, - }, + {"time_limit": 600_000}, ), ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), ("dramatiq.middleware.callbacks.Callbacks", {}), ("dramatiq.middleware.pipelines.Pipelines", {}), - ("dramatiq.middleware.retries.Retries", {"max_retries": 20 if not TEST else 0}), + ( + "dramatiq.middleware.retries.Retries", + {"max_retries": 20 if not TEST else 0}, + ), # TODO: results ("authentik.tasks.middleware.FullyQualifiedActorName", {}), + ("authentik.tasks.middleware.RelObjMiddleware", {}), + ("authentik.tasks.middleware.TenantMiddleware", {}), ("authentik.tasks.middleware.CurrentTask", {}), ), - "task_class": "authentik.tasks.models.Task", + "test": TEST, } diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py index a1255cc124..7a0f14810f 100644 --- a/authentik/tasks/apps.py +++ b/authentik/tasks/apps.py @@ -1,17 +1,3 @@ -import dramatiq -from dramatiq.broker import Broker, get_broker -from dramatiq.encoder import PickleEncoder -from dramatiq.middleware import ( - AgeLimit, - Callbacks, - Pipelines, - # Prometheus, - Retries, - ShutdownNotifications, - TimeLimit, -) -from dramatiq.results.middleware import Results - from authentik.blueprints.apps import ManagedAppConfig @@ -21,47 +7,14 @@ class AuthentikTasksConfig(ManagedAppConfig): verbose_name = "authentik Tasks" default = True - def _set_dramatiq_middlewares(self, broker: Broker, max_retries: int = 20) -> None: - from authentik.tasks.middleware import CurrentTask, FullyQualifiedActorName - from authentik.tasks.results import PostgresBackend - - # TODO: fixme - # broker.add_middleware(Prometheus()) - broker.add_middleware(AgeLimit()) - # Task timeout, 5 minutes by default for all tasks - broker.add_middleware(TimeLimit(time_limit=600 * 1000)) - broker.add_middleware(ShutdownNotifications()) - broker.add_middleware(Callbacks()) - broker.add_middleware(Pipelines()) - broker.add_middleware(Retries(max_retries=max_retries)) - broker.add_middleware(Results(backend=PostgresBackend(), store_results=True)) - - broker.add_middleware(FullyQualifiedActorName()) - broker.add_middleware(CurrentTask()) - - def ready(self) -> None: - from authentik.tasks.broker import PostgresBroker - - old_broker = dramatiq.get_broker() - if len(old_broker.actors) != 0: - raise RuntimeError("Mis-registered actors") - - dramatiq.set_encoder(PickleEncoder()) - - broker = PostgresBroker(middleware=[]) - self._set_dramatiq_middlewares(broker) - dramatiq.set_broker(broker) - - return super().ready() - - def use_test_broker(self) -> None: - from authentik.tasks.test import TestBroker - - old_broker = get_broker() - broker = TestBroker(middleware=[]) - self._set_dramatiq_middlewares(broker, max_retries=0) - dramatiq.set_broker(broker) - for actor_name in old_broker.get_declared_actors(): - actor = old_broker.get_actor(actor_name) - actor.broker = broker - actor.broker.declare_actor(actor) + # def use_test_broker(self) -> None: + # from authentik.tasks.test import TestBroker + # + # old_broker = get_broker() + # broker = TestBroker(middleware=[]) + # self._set_dramatiq_middlewares(broker, max_retries=0) + # dramatiq.set_broker(broker) + # for actor_name in old_broker.get_declared_actors(): + # actor = old_broker.get_actor(actor_name) + # actor.broker = broker + # actor.broker.declare_actor(actor) diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 793e70b9dd..41711de4e7 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,73 +1,15 @@ from typing import Any -from django_dramatiq_postgres.middleware import DbConnectionMiddleware -from django.db import ( - DEFAULT_DB_ALIAS, -) -from dramatiq.broker import Broker +from django_dramatiq_postgres.broker import PostgresBroker from dramatiq.message import Message -from dramatiq.middleware import ( - AgeLimit, - Callbacks, - Middleware, - Pipelines, - Prometheus, - Retries, - ShutdownNotifications, - TimeLimit, -) from structlog.stdlib import get_logger -from django_dramatiq_postgres.broker import PostgresBroker as PostgresBrokerBase -from authentik.tasks.models import Task -from authentik.tenants.models import Tenant from authentik.tenants.utils import get_current_tenant LOGGER = get_logger() -class TenantMiddleware(Middleware): - def before_process_message(self, broker: Broker, message: Message): - Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate() - - def after_process_message(self, *args, **kwargs): - Tenant.deactivate() - - -class PostgresBroker(PostgresBrokerBase): - def __init__( - self, - *args, - middleware: list[Middleware] | None = None, - db_alias: str = DEFAULT_DB_ALIAS, - **kwargs, - ): - super().__init__(*args, middleware=[], **kwargs) - self.logger = get_logger().bind() - - self.queues = set() - self.actor_options = { - "rel_obj", - } - - self.db_alias = db_alias - self.middleware = [] - self.add_middleware(DbConnectionMiddleware()) - self.add_middleware(TenantMiddleware()) - if middleware is None: - for m in ( - Prometheus, - AgeLimit, - TimeLimit, - ShutdownNotifications, - Callbacks, - Pipelines, - Retries, - ): - self.add_middleware(m()) - for m in middleware or []: - self.add_middleware(m) - +class Broker(PostgresBroker): def model_defaults(self, message: Message) -> dict[str, Any]: rel_obj = message.options.get("rel_obj") if rel_obj: diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py index a8a359eefb..4fe6f8b080 100644 --- a/authentik/tasks/middleware.py +++ b/authentik/tasks/middleware.py @@ -8,10 +8,17 @@ from dramatiq.middleware import Middleware from structlog.stdlib import get_logger from authentik.tasks.models import Task +from authentik.tenants.models import Tenant LOGGER = get_logger() +class RelObjMiddleware(Middleware): + @property + def actor_options(self): + return {"rel_obj"} + + class FullyQualifiedActorName(Middleware): def before_declare_actor(self, broker: Broker, actor: Actor): actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" @@ -52,3 +59,11 @@ class CurrentTask(Middleware): else: tasks[-1].save() self._TASK.set(tasks[:-1]) + + +class TenantMiddleware(Middleware): + def before_process_message(self, broker: Broker, message: Message): + Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate() + + def after_process_message(self, *args, **kwargs): + Tenant.deactivate() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py index f6c870813c..d43199bf1f 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py @@ -19,20 +19,20 @@ class DjangoDramatiqPostgres(AppConfig): "Make sure your actors are not imported too early." ) - encoder: dramatiq.encoder.Encoder = import_string(Conf.encoder_class)() + encoder: dramatiq.encoder.Encoder = import_string(Conf().encoder_class)() dramatiq.set_encoder(encoder) - broker_args = Conf.broker_args + broker_args = Conf().broker_args broker_kwargs = { - **Conf.broker_kwargs, + **Conf().broker_kwargs, "middleware": [], } - broker: dramatiq.broker.Broker = import_string(Conf.broker_class)( + broker: dramatiq.broker.Broker = import_string(Conf().broker_class)( *broker_args, **broker_kwargs, ) - for middleware_class, middleware_kwargs in Conf.middlewares: + for middleware_class, middleware_kwargs in Conf().middlewares: middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)( **middleware_kwargs, ) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py index e7a9afb215..b9bdc4bf63 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py @@ -7,6 +7,7 @@ from random import randint from typing import Any import tenacity +from django.core.exceptions import ImproperlyConfigured from django.db import ( DEFAULT_DB_ALIAS, DatabaseError, @@ -24,14 +25,7 @@ from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name from dramatiq.errors import ConnectionError, QueueJoinTimeout from dramatiq.message import Message from dramatiq.middleware import ( - AgeLimit, - Callbacks, Middleware, - Pipelines, - Prometheus, - Retries, - ShutdownNotifications, - TimeLimit, ) from pglock.core import _cast_lock_id from psycopg import Notify, sql @@ -39,7 +33,6 @@ from psycopg.errors import AdminShutdown from structlog.stdlib import get_logger from django_dramatiq_postgres.conf import Conf -from django_dramatiq_postgres.middleware import DbConnectionMiddleware from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState LOGGER = get_logger() @@ -75,20 +68,10 @@ class PostgresBroker(Broker): self.db_alias = db_alias self.middleware = [] - if middleware is None: - for m in ( - DbConnectionMiddleware, - Prometheus, - AgeLimit, - TimeLimit, - ShutdownNotifications, - Callbacks, - Pipelines, - Retries, - ): - self.add_middleware(m()) - for m in middleware or []: - self.add_middleware(m) + if middleware: + raise ImproperlyConfigured( + "Middlewares should be set in django settings, not passed directly to the broker." + ) @property def connection(self) -> DatabaseWrapper: @@ -100,7 +83,7 @@ class PostgresBroker(Broker): @cached_property def model(self) -> type[TaskBase]: - return import_string(Conf.task_class) + return import_string(Conf().task_class) @property def query_set(self) -> QuerySet: @@ -267,7 +250,6 @@ class _PostgresConsumer(Consumer): @raise_connection_error def ack(self, message: Message): - self.unlock_queue.put_nowait(message) self.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, @@ -276,11 +258,11 @@ class _PostgresConsumer(Consumer): state=TaskState.DONE, message=message.encode(), ) + self.unlock_queue.put_nowait(message.message_id) self.in_processing.remove(message.message_id) @raise_connection_error def nack(self, message: Message): - self.unlock_queue.put_nowait(message) self.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, @@ -290,18 +272,18 @@ class _PostgresConsumer(Consumer): state=TaskState.REJECTED, message=message.encode(), ) + self.unlock_queue.put_nowait(message.message_id) self.in_processing.remove(message.message_id) @raise_connection_error def requeue(self, messages: Iterable[Message]): - for message in messages: - self.unlock_queue.put_nowait(message) self.query_set.filter( message_id__in=[message.message_id for message in messages], ).update( state=TaskState.QUEUED, ) for message in messages: + self.unlock_queue.put_nowait(message.message_id) self.in_processing.remove(message.message_id) self._purge_locks() @@ -328,9 +310,9 @@ class _PostgresConsumer(Consumer): ) self.notifies += notifies - def _get_message_lock_id(self, message: Message) -> int: + def _get_message_lock_id(self, message_id: str) -> int: return _cast_lock_id( - f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}" + f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}" ) def _consume_one(self, message: Message) -> bool: @@ -345,7 +327,7 @@ class _PostgresConsumer(Consumer): ) .extra( where=["pg_try_advisory_lock(%s)"], - params=[self._get_message_lock_id(message)], + params=[self._get_message_lock_id(message.message_id)], ) .update( state=TaskState.CONSUMED, @@ -391,6 +373,7 @@ class _PostgresConsumer(Consumer): notify = self.notifies.pop(0) task = self.query_set.get(message_id=notify.payload) message = Message.decode(task.message) + message.task = task if self._consume_one(message): self.in_processing.add(message.message_id) return MessageProxy(message) @@ -404,17 +387,18 @@ class _PostgresConsumer(Consumer): def _purge_locks(self): while True: try: - message = self.unlock_queue.get(block=False) + message_id = self.unlock_queue.get(block=False) except Empty: return - self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}") + self.logger.debug(f"Unlocking message {message_id}") with self.connection.cursor() as cursor: cursor.execute( - "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),) + "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),) ) self.unlock_queue.task_done() def _auto_purge(self): + # TODO: allow configuring this # Automatically purge messages on average every 100k iteration. # Dramatiq defaults to 1s, so this means one purge every 28 hours. if randint(0, 100_000): # nosec @@ -422,6 +406,7 @@ class _PostgresConsumer(Consumer): self.logger.debug("Running garbage collector") count = self.query_set.filter( state__in=(TaskState.DONE, TaskState.REJECTED), + # TODO: allow configuring this mtime__lte=timezone.now() - timezone.timedelta(days=30), result_expiry__lte=timezone.now(), ).delete() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py index c794814463..08297ad3b2 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py @@ -1,33 +1,57 @@ +from typing import Any + from django.conf import settings +from django.core.exceptions import ImproperlyConfigured class Conf: - try: - conf = settings.DRAMATIQ.copy() - except AttributeError: - conf = {} + def __init__(self): + try: + self.conf = settings.DRAMATIQ + except AttributeError as exc: + raise ImproperlyConfigured("Setting DRAMATIQ not set.") from exc + if "task_class" not in self.conf: + raise ImproperlyConfigured("DRAMATIQ.task_class not defined") - encoder_class = conf.get("encoder_class", "dramatiq.encoder.PickleEncoder") + @property + def encoder_class(self) -> str: + return self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder") - broker_class = conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker") - broker_args = conf.get("broker_args", ()) - broker_kwargs = conf.get("broker_kwargs", {}) + @property + def broker_class(self) -> str: + return self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker") - middlewares = conf.get( - "middlewares", - ( - ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), - ("dramatiq.middleware.age_limit.AgeLimit", {}), - ("dramatiq.middleware.time_limit.TimeLimit", {}), - ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), - ("dramatiq.middleware.callbacks.Callbacks", {}), - ("dramatiq.middleware.pipelines.Pipelines", {}), - ("dramatiq.middleware.retries.Retries", {}), - ), - ) + @property + def broker_args(self) -> tuple[Any]: + return self.conf.get("broker_args", ()) - channel_prefix = conf.get("channel_prefix", "dramatiq.tasks") + @property + def broker_kwargs(self) -> dict[str, Any]: + return self.conf.get("broker_kwargs", {}) - task_class = conf.get("task_class", None) + @property + def middlewares(self) -> tuple[tuple[str, dict[str, Any]]]: + return self.conf.get( + "middlewares", + ( + ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), + ("dramatiq.middleware.age_limit.AgeLimit", {}), + ("dramatiq.middleware.time_limit.TimeLimit", {}), + ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), + ("dramatiq.middleware.callbacks.Callbacks", {}), + ("dramatiq.middleware.pipelines.Pipelines", {}), + ("dramatiq.middleware.retries.Retries", {}), + ), + ) - test = conf.get("test", False) + @property + def channel_prefix(self) -> str: + return self.conf.get("channel_prefix", "dramatiq.tasks") + + @property + def task_class(self) -> str: + return self.conf["task_class"] + + @property + def test(self) -> bool: + return self.conf.get("test", False) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py index c96b82a9a9..0f931b7ebd 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -9,7 +9,7 @@ from django_dramatiq_postgres.conf import Conf class DbConnectionMiddleware(Middleware): def _close_old_connections(self, *args, **kwargs): - if Conf.test: + if Conf().test: return close_old_connections() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py index 984408611d..5cf09dfac1 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/models.py @@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _ from django_dramatiq_postgres.conf import Conf -CHANNEL_PREFIX = f"{Conf.channel_prefix}.tasks" +CHANNEL_PREFIX = f"{Conf().channel_prefix}.tasks" class ChannelIdentifier(StrEnum):