diff --git a/authentik/root/settings.py b/authentik/root/settings.py index ce61b13b8e..62b7e5087b 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -401,7 +401,7 @@ DRAMATIQ = { "dramatiq.middleware.retries.Retries", {"max_retries": CONFIG.get_int("worker.task_max_retries") if not TEST else 0}, ), - # TODO: results + ("dramatiq.results.middleware.Results", {"store_results": True}), ("django_dramatiq_postgres.middleware.CurrentTask", {}), ("authentik.tasks.middleware.TenantMiddleware", {}), ("authentik.tasks.middleware.RelObjMiddleware", {}), diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py index d43199bf1f..89873dfc14 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/apps.py @@ -2,6 +2,7 @@ import dramatiq from django.apps import AppConfig from django.core.exceptions import ImproperlyConfigured from django.utils.module_loading import import_string +from dramatiq.results.middleware import Results from django_dramatiq_postgres.conf import Conf @@ -22,20 +23,21 @@ class DjangoDramatiqPostgres(AppConfig): encoder: dramatiq.encoder.Encoder = import_string(Conf().encoder_class)() dramatiq.set_encoder(encoder) - broker_args = Conf().broker_args - broker_kwargs = { - **Conf().broker_kwargs, - "middleware": [], - } broker: dramatiq.broker.Broker = import_string(Conf().broker_class)( - *broker_args, - **broker_kwargs, + *Conf().broker_args, + **Conf().broker_kwargs, + middleware=[], ) for middleware_class, middleware_kwargs in Conf().middlewares: middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)( **middleware_kwargs, ) + if isinstance(middleware, Results): + middleware.backend = import_string(Conf().result_backend)( + *Conf().result_backend_args, + **Conf().result_backend_kwargs, + ) broker.add_middleware(middleware) dramatiq.set_broker(broker) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py index 929d98cbec..e6bb2239cd 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py @@ -66,6 +66,18 @@ class Conf: # 30 days return self.conf.get("task_expiration", 60 * 60 * 24 * 30) + @property + def result_backend(self) -> str: + return self.conf.get("result_backend", "django_dramatiq_postgres.results.PostgresBackend") + + @property + def result_backend_args(self) -> tuple[Any]: + return self.conf.get("result_backend_args", ()) + + @property + def result_backend_kwargs(self) -> dict[str, Any]: + return self.conf.get("result_backend_kwargs", {}) + @property def autodiscovery(self) -> dict[str, Any]: autodiscovery = { diff --git a/authentik/tasks/results.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py similarity index 76% rename from authentik/tasks/results.py rename to packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py index f397637532..faf36c14d3 100644 --- a/authentik/tasks/results.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/results.py @@ -1,10 +1,13 @@ from django.db import DEFAULT_DB_ALIAS from django.db.models import QuerySet from django.utils import timezone +from django.utils.functional import cached_property +from django.utils.module_loading import import_string from dramatiq.message import Message from dramatiq.results.backend import Missing, MResult, Result, ResultBackend -from authentik.tasks.models import Task +from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import TaskBase class PostgresBackend(ResultBackend): @@ -12,9 +15,13 @@ class PostgresBackend(ResultBackend): super().__init__(*args, **kwargs) self.db_alias = db_alias + @cached_property + def model(self) -> type[TaskBase]: + return import_string(Conf().task_model) + @property def query_set(self) -> QuerySet: - return Task.objects.using(self.db_alias) + return self.model.objects.using(self.db_alias) def build_message_key(self, message: Message) -> str: return str(message.message_id)