move all broker stuff to package, schedule is still todo

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-18 17:01:05 +02:00
parent 8c7818a252
commit 16fd9cab67
9 changed files with 113 additions and 187 deletions

View File

@ -354,26 +354,33 @@ TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
# Dramatiq # Dramatiq
DRAMATIQ = { DRAMATIQ = {
"broker_class": "authentik.tasks.broker.Broker",
"channel_prefix": "authentik.tasks",
"task_class": "authentik.tasks.models.Task",
"middlewares": ( "middlewares": (
# TODO: fixme # TODO: fixme
# ("dramatiq.middleware.prometheus.Prometheus", {}), # ("dramatiq.middleware.prometheus.Prometheus", {}),
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}), ("dramatiq.middleware.age_limit.AgeLimit", {}),
( (
# 5 minutes task timeout by default for all tasks, in ms
"dramatiq.middleware.time_limit.TimeLimit", "dramatiq.middleware.time_limit.TimeLimit",
{ {"time_limit": 600_000},
# 5 minutes task timeout by default for all tasks
"time_limit": 600 * 1000,
},
), ),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}), ("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}), ("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}), ("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {"max_retries": 20 if not TEST else 0}), (
"dramatiq.middleware.retries.Retries",
{"max_retries": 20 if not TEST else 0},
),
# TODO: results # TODO: results
("authentik.tasks.middleware.FullyQualifiedActorName", {}), ("authentik.tasks.middleware.FullyQualifiedActorName", {}),
("authentik.tasks.middleware.RelObjMiddleware", {}),
("authentik.tasks.middleware.TenantMiddleware", {}),
("authentik.tasks.middleware.CurrentTask", {}), ("authentik.tasks.middleware.CurrentTask", {}),
), ),
"task_class": "authentik.tasks.models.Task", "test": TEST,
} }

View File

@ -1,17 +1,3 @@
import dramatiq
from dramatiq.broker import Broker, get_broker
from dramatiq.encoder import PickleEncoder
from dramatiq.middleware import (
AgeLimit,
Callbacks,
Pipelines,
# Prometheus,
Retries,
ShutdownNotifications,
TimeLimit,
)
from dramatiq.results.middleware import Results
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
@ -21,47 +7,14 @@ class AuthentikTasksConfig(ManagedAppConfig):
verbose_name = "authentik Tasks" verbose_name = "authentik Tasks"
default = True default = True
def _set_dramatiq_middlewares(self, broker: Broker, max_retries: int = 20) -> None: # def use_test_broker(self) -> None:
from authentik.tasks.middleware import CurrentTask, FullyQualifiedActorName # from authentik.tasks.test import TestBroker
from authentik.tasks.results import PostgresBackend #
# old_broker = get_broker()
# TODO: fixme # broker = TestBroker(middleware=[])
# broker.add_middleware(Prometheus()) # self._set_dramatiq_middlewares(broker, max_retries=0)
broker.add_middleware(AgeLimit()) # dramatiq.set_broker(broker)
# Task timeout, 5 minutes by default for all tasks # for actor_name in old_broker.get_declared_actors():
broker.add_middleware(TimeLimit(time_limit=600 * 1000)) # actor = old_broker.get_actor(actor_name)
broker.add_middleware(ShutdownNotifications()) # actor.broker = broker
broker.add_middleware(Callbacks()) # actor.broker.declare_actor(actor)
broker.add_middleware(Pipelines())
broker.add_middleware(Retries(max_retries=max_retries))
broker.add_middleware(Results(backend=PostgresBackend(), store_results=True))
broker.add_middleware(FullyQualifiedActorName())
broker.add_middleware(CurrentTask())
def ready(self) -> None:
from authentik.tasks.broker import PostgresBroker
old_broker = dramatiq.get_broker()
if len(old_broker.actors) != 0:
raise RuntimeError("Mis-registered actors")
dramatiq.set_encoder(PickleEncoder())
broker = PostgresBroker(middleware=[])
self._set_dramatiq_middlewares(broker)
dramatiq.set_broker(broker)
return super().ready()
def use_test_broker(self) -> None:
from authentik.tasks.test import TestBroker
old_broker = get_broker()
broker = TestBroker(middleware=[])
self._set_dramatiq_middlewares(broker, max_retries=0)
dramatiq.set_broker(broker)
for actor_name in old_broker.get_declared_actors():
actor = old_broker.get_actor(actor_name)
actor.broker = broker
actor.broker.declare_actor(actor)

View File

@ -1,73 +1,15 @@
from typing import Any from typing import Any
from django_dramatiq_postgres.middleware import DbConnectionMiddleware from django_dramatiq_postgres.broker import PostgresBroker
from django.db import (
DEFAULT_DB_ALIAS,
)
from dramatiq.broker import Broker
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import (
AgeLimit,
Callbacks,
Middleware,
Pipelines,
Prometheus,
Retries,
ShutdownNotifications,
TimeLimit,
)
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from django_dramatiq_postgres.broker import PostgresBroker as PostgresBrokerBase
from authentik.tasks.models import Task
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger() LOGGER = get_logger()
class TenantMiddleware(Middleware): class Broker(PostgresBroker):
def before_process_message(self, broker: Broker, message: Message):
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
def after_process_message(self, *args, **kwargs):
Tenant.deactivate()
class PostgresBroker(PostgresBrokerBase):
def __init__(
self,
*args,
middleware: list[Middleware] | None = None,
db_alias: str = DEFAULT_DB_ALIAS,
**kwargs,
):
super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger().bind()
self.queues = set()
self.actor_options = {
"rel_obj",
}
self.db_alias = db_alias
self.middleware = []
self.add_middleware(DbConnectionMiddleware())
self.add_middleware(TenantMiddleware())
if middleware is None:
for m in (
Prometheus,
AgeLimit,
TimeLimit,
ShutdownNotifications,
Callbacks,
Pipelines,
Retries,
):
self.add_middleware(m())
for m in middleware or []:
self.add_middleware(m)
def model_defaults(self, message: Message) -> dict[str, Any]: def model_defaults(self, message: Message) -> dict[str, Any]:
rel_obj = message.options.get("rel_obj") rel_obj = message.options.get("rel_obj")
if rel_obj: if rel_obj:

View File

@ -8,10 +8,17 @@ from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.tasks.models import Task from authentik.tasks.models import Task
from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
class RelObjMiddleware(Middleware):
@property
def actor_options(self):
return {"rel_obj"}
class FullyQualifiedActorName(Middleware): class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor): def before_declare_actor(self, broker: Broker, actor: Actor):
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
@ -52,3 +59,11 @@ class CurrentTask(Middleware):
else: else:
tasks[-1].save() tasks[-1].save()
self._TASK.set(tasks[:-1]) self._TASK.set(tasks[:-1])
class TenantMiddleware(Middleware):
def before_process_message(self, broker: Broker, message: Message):
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
def after_process_message(self, *args, **kwargs):
Tenant.deactivate()

View File

@ -19,20 +19,20 @@ class DjangoDramatiqPostgres(AppConfig):
"Make sure your actors are not imported too early." "Make sure your actors are not imported too early."
) )
encoder: dramatiq.encoder.Encoder = import_string(Conf.encoder_class)() encoder: dramatiq.encoder.Encoder = import_string(Conf().encoder_class)()
dramatiq.set_encoder(encoder) dramatiq.set_encoder(encoder)
broker_args = Conf.broker_args broker_args = Conf().broker_args
broker_kwargs = { broker_kwargs = {
**Conf.broker_kwargs, **Conf().broker_kwargs,
"middleware": [], "middleware": [],
} }
broker: dramatiq.broker.Broker = import_string(Conf.broker_class)( broker: dramatiq.broker.Broker = import_string(Conf().broker_class)(
*broker_args, *broker_args,
**broker_kwargs, **broker_kwargs,
) )
for middleware_class, middleware_kwargs in Conf.middlewares: for middleware_class, middleware_kwargs in Conf().middlewares:
middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)( middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)(
**middleware_kwargs, **middleware_kwargs,
) )

View File

@ -7,6 +7,7 @@ from random import randint
from typing import Any from typing import Any
import tenacity import tenacity
from django.core.exceptions import ImproperlyConfigured
from django.db import ( from django.db import (
DEFAULT_DB_ALIAS, DEFAULT_DB_ALIAS,
DatabaseError, DatabaseError,
@ -24,14 +25,7 @@ from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import ( from dramatiq.middleware import (
AgeLimit,
Callbacks,
Middleware, Middleware,
Pipelines,
Prometheus,
Retries,
ShutdownNotifications,
TimeLimit,
) )
from pglock.core import _cast_lock_id from pglock.core import _cast_lock_id
from psycopg import Notify, sql from psycopg import Notify, sql
@ -39,7 +33,6 @@ from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.middleware import DbConnectionMiddleware
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
LOGGER = get_logger() LOGGER = get_logger()
@ -75,20 +68,10 @@ class PostgresBroker(Broker):
self.db_alias = db_alias self.db_alias = db_alias
self.middleware = [] self.middleware = []
if middleware is None: if middleware:
for m in ( raise ImproperlyConfigured(
DbConnectionMiddleware, "Middlewares should be set in django settings, not passed directly to the broker."
Prometheus, )
AgeLimit,
TimeLimit,
ShutdownNotifications,
Callbacks,
Pipelines,
Retries,
):
self.add_middleware(m())
for m in middleware or []:
self.add_middleware(m)
@property @property
def connection(self) -> DatabaseWrapper: def connection(self) -> DatabaseWrapper:
@ -100,7 +83,7 @@ class PostgresBroker(Broker):
@cached_property @cached_property
def model(self) -> type[TaskBase]: def model(self) -> type[TaskBase]:
return import_string(Conf.task_class) return import_string(Conf().task_class)
@property @property
def query_set(self) -> QuerySet: def query_set(self) -> QuerySet:
@ -267,7 +250,6 @@ class _PostgresConsumer(Consumer):
@raise_connection_error @raise_connection_error
def ack(self, message: Message): def ack(self, message: Message):
self.unlock_queue.put_nowait(message)
self.query_set.filter( self.query_set.filter(
message_id=message.message_id, message_id=message.message_id,
queue_name=message.queue_name, queue_name=message.queue_name,
@ -276,11 +258,11 @@ class _PostgresConsumer(Consumer):
state=TaskState.DONE, state=TaskState.DONE,
message=message.encode(), message=message.encode(),
) )
self.unlock_queue.put_nowait(message.message_id)
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)
@raise_connection_error @raise_connection_error
def nack(self, message: Message): def nack(self, message: Message):
self.unlock_queue.put_nowait(message)
self.query_set.filter( self.query_set.filter(
message_id=message.message_id, message_id=message.message_id,
queue_name=message.queue_name, queue_name=message.queue_name,
@ -290,18 +272,18 @@ class _PostgresConsumer(Consumer):
state=TaskState.REJECTED, state=TaskState.REJECTED,
message=message.encode(), message=message.encode(),
) )
self.unlock_queue.put_nowait(message.message_id)
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)
@raise_connection_error @raise_connection_error
def requeue(self, messages: Iterable[Message]): def requeue(self, messages: Iterable[Message]):
for message in messages:
self.unlock_queue.put_nowait(message)
self.query_set.filter( self.query_set.filter(
message_id__in=[message.message_id for message in messages], message_id__in=[message.message_id for message in messages],
).update( ).update(
state=TaskState.QUEUED, state=TaskState.QUEUED,
) )
for message in messages: for message in messages:
self.unlock_queue.put_nowait(message.message_id)
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)
self._purge_locks() self._purge_locks()
@ -328,9 +310,9 @@ class _PostgresConsumer(Consumer):
) )
self.notifies += notifies self.notifies += notifies
def _get_message_lock_id(self, message: Message) -> int: def _get_message_lock_id(self, message_id: str) -> int:
return _cast_lock_id( return _cast_lock_id(
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}" f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}"
) )
def _consume_one(self, message: Message) -> bool: def _consume_one(self, message: Message) -> bool:
@ -345,7 +327,7 @@ class _PostgresConsumer(Consumer):
) )
.extra( .extra(
where=["pg_try_advisory_lock(%s)"], where=["pg_try_advisory_lock(%s)"],
params=[self._get_message_lock_id(message)], params=[self._get_message_lock_id(message.message_id)],
) )
.update( .update(
state=TaskState.CONSUMED, state=TaskState.CONSUMED,
@ -391,6 +373,7 @@ class _PostgresConsumer(Consumer):
notify = self.notifies.pop(0) notify = self.notifies.pop(0)
task = self.query_set.get(message_id=notify.payload) task = self.query_set.get(message_id=notify.payload)
message = Message.decode(task.message) message = Message.decode(task.message)
message.task = task
if self._consume_one(message): if self._consume_one(message):
self.in_processing.add(message.message_id) self.in_processing.add(message.message_id)
return MessageProxy(message) return MessageProxy(message)
@ -404,17 +387,18 @@ class _PostgresConsumer(Consumer):
def _purge_locks(self): def _purge_locks(self):
while True: while True:
try: try:
message = self.unlock_queue.get(block=False) message_id = self.unlock_queue.get(block=False)
except Empty: except Empty:
return return
self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}") self.logger.debug(f"Unlocking message {message_id}")
with self.connection.cursor() as cursor: with self.connection.cursor() as cursor:
cursor.execute( cursor.execute(
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),) "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),)
) )
self.unlock_queue.task_done() self.unlock_queue.task_done()
def _auto_purge(self): def _auto_purge(self):
# TODO: allow configuring this
# Automatically purge messages on average every 100k iteration. # Automatically purge messages on average every 100k iteration.
# Dramatiq defaults to 1s, so this means one purge every 28 hours. # Dramatiq defaults to 1s, so this means one purge every 28 hours.
if randint(0, 100_000): # nosec if randint(0, 100_000): # nosec
@ -422,6 +406,7 @@ class _PostgresConsumer(Consumer):
self.logger.debug("Running garbage collector") self.logger.debug("Running garbage collector")
count = self.query_set.filter( count = self.query_set.filter(
state__in=(TaskState.DONE, TaskState.REJECTED), state__in=(TaskState.DONE, TaskState.REJECTED),
# TODO: allow configuring this
mtime__lte=timezone.now() - timezone.timedelta(days=30), mtime__lte=timezone.now() - timezone.timedelta(days=30),
result_expiry__lte=timezone.now(), result_expiry__lte=timezone.now(),
).delete() ).delete()

View File

@ -1,33 +1,57 @@
from typing import Any
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
class Conf: class Conf:
try: def __init__(self):
conf = settings.DRAMATIQ.copy() try:
except AttributeError: self.conf = settings.DRAMATIQ
conf = {} except AttributeError as exc:
raise ImproperlyConfigured("Setting DRAMATIQ not set.") from exc
if "task_class" not in self.conf:
raise ImproperlyConfigured("DRAMATIQ.task_class not defined")
encoder_class = conf.get("encoder_class", "dramatiq.encoder.PickleEncoder") @property
def encoder_class(self) -> str:
return self.conf.get("encoder_class", "dramatiq.encoder.PickleEncoder")
broker_class = conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker") @property
broker_args = conf.get("broker_args", ()) def broker_class(self) -> str:
broker_kwargs = conf.get("broker_kwargs", {}) return self.conf.get("broker_class", "django_dramatiq_postgres.broker.PostgresBroker")
middlewares = conf.get( @property
"middlewares", def broker_args(self) -> tuple[Any]:
( return self.conf.get("broker_args", ())
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
)
channel_prefix = conf.get("channel_prefix", "dramatiq.tasks") @property
def broker_kwargs(self) -> dict[str, Any]:
return self.conf.get("broker_kwargs", {})
task_class = conf.get("task_class", None) @property
def middlewares(self) -> tuple[tuple[str, dict[str, Any]]]:
return self.conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
("dramatiq.middleware.retries.Retries", {}),
),
)
test = conf.get("test", False) @property
def channel_prefix(self) -> str:
return self.conf.get("channel_prefix", "dramatiq.tasks")
@property
def task_class(self) -> str:
return self.conf["task_class"]
@property
def test(self) -> bool:
return self.conf.get("test", False)

View File

@ -9,7 +9,7 @@ from django_dramatiq_postgres.conf import Conf
class DbConnectionMiddleware(Middleware): class DbConnectionMiddleware(Middleware):
def _close_old_connections(self, *args, **kwargs): def _close_old_connections(self, *args, **kwargs):
if Conf.test: if Conf().test:
return return
close_old_connections() close_old_connections()

View File

@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.conf import Conf
CHANNEL_PREFIX = f"{Conf.channel_prefix}.tasks" CHANNEL_PREFIX = f"{Conf().channel_prefix}.tasks"
class ChannelIdentifier(StrEnum): class ChannelIdentifier(StrEnum):