diff --git a/authentik/events/migrations/0010_delete_systemtask.py b/authentik/events/migrations/0011_delete_systemtask.py
similarity index 58%
rename from authentik/events/migrations/0010_delete_systemtask.py
rename to authentik/events/migrations/0011_delete_systemtask.py
index 973868c29c..34993eb83e 100644
--- a/authentik/events/migrations/0010_delete_systemtask.py
+++ b/authentik/events/migrations/0011_delete_systemtask.py
@@ -1,4 +1,4 @@
-# Generated by Django 5.1.9 on 2025-06-06 13:25
+# Generated by Django 5.1.11 on 2025-06-18 12:43
 
 from django.db import migrations
 
@@ -6,7 +6,7 @@ from django.db import migrations
 class Migration(migrations.Migration):
 
     dependencies = [
-        ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
+        ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"),
     ]
 
     operations = [
diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py
index 7a5dbf9f02..793e70b9dd 100644
--- a/authentik/tasks/broker.py
+++ b/authentik/tasks/broker.py
@@ -1,26 +1,10 @@
-import functools
-import logging
-import time
-from collections.abc import Iterable
-from queue import Empty, Queue
-from random import randint
+from typing import Any
 
-import tenacity
-from django.conf import settings
+from django_dramatiq_postgres.middleware import DbConnectionMiddleware
 from django.db import (
     DEFAULT_DB_ALIAS,
-    DatabaseError,
-    InterfaceError,
-    OperationalError,
-    close_old_connections,
-    connections,
 )
-from django.db.backends.postgresql.base import DatabaseWrapper
-from django.db.models import QuerySet
-from django.utils import timezone
-from dramatiq.broker import Broker, Consumer, MessageProxy
-from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
-from dramatiq.errors import ConnectionError, QueueJoinTimeout
+from dramatiq.broker import Broker
 from dramatiq.message import Message
 from dramatiq.middleware import (
     AgeLimit,
@@ -32,51 +16,16 @@ from dramatiq.middleware import (
     ShutdownNotifications,
     TimeLimit,
 )
-from pglock.core import _cast_lock_id
-from psycopg import Notify, sql
-from psycopg.errors import AdminShutdown
 from structlog.stdlib import get_logger
+from django_dramatiq_postgres.broker import PostgresBroker as PostgresBrokerBase
 
-from authentik.tasks.models import CHANNEL_PREFIX, ChannelIdentifier, Task, TaskState
-from authentik.tasks.schedules.scheduler import Scheduler
+from authentik.tasks.models import Task
 from authentik.tenants.models import Tenant
 from authentik.tenants.utils import get_current_tenant
 
 LOGGER = get_logger()
 
 
-def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
-    return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
-
-
-def raise_connection_error(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        try:
-            return func(*args, **kwargs)
-        except OperationalError as exc:
-            raise ConnectionError(str(exc)) from exc
-
-    return wrapper
-
-
-class DbConnectionMiddleware(Middleware):
-    def _close_old_connections(self, *args, **kwargs):
-        if settings.TEST:
-            return
-        close_old_connections()
-
-    before_process_message = _close_old_connections
-    after_process_message = _close_old_connections
-
-    def _close_connections(self, *args, **kwargs):
-        connections.close_all()
-
-    before_consumer_thread_shutdown = _close_connections
-    before_worker_thread_shutdown = _close_connections
-    before_worker_shutdown = _close_connections
-
-
 class TenantMiddleware(Middleware):
     def before_process_message(self, broker: Broker, message: Message):
         Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
@@ -85,7 +34,7 @@ class TenantMiddleware(Middleware):
         Tenant.deactivate()
 
 
-class PostgresBroker(Broker):
+class PostgresBroker(PostgresBrokerBase):
     def __init__(
         self,
         *args,
@@ -119,340 +68,12 @@ class PostgresBroker(Broker):
         for m in middleware or []:
             self.add_middleware(m)
 
-    @property
-    def connection(self) -> DatabaseWrapper:
-        return connections[self.db_alias]
-
-    @property
-    def consumer_class(self) -> "type[_PostgresConsumer]":
-        return _PostgresConsumer
-
-    @property
-    def query_set(self) -> QuerySet:
-        return Task.objects.using(self.db_alias)
-
-    def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
-        self.declare_queue(queue_name)
-        return self.consumer_class(
-            broker=self,
-            db_alias=self.db_alias,
-            queue_name=queue_name,
-            prefetch=prefetch,
-            timeout=timeout,
-        )
-
-    def declare_queue(self, queue_name: str):
-        if queue_name not in self.queues:
-            self.emit_before("declare_queue", queue_name)
-            self.queues.add(queue_name)
-            # Nothing to do, all queues are in the same table
-            self.emit_after("declare_queue", queue_name)
-
-            delayed_name = dq_name(queue_name)
-            self.delay_queues.add(delayed_name)
-            self.emit_after("declare_delay_queue", delayed_name)
-
-    @tenacity.retry(
-        retry=tenacity.retry_if_exception_type(
-            (
-                AdminShutdown,
-                InterfaceError,
-                DatabaseError,
-                ConnectionError,
-                OperationalError,
-            )
-        ),
-        reraise=True,
-        wait=tenacity.wait_random_exponential(multiplier=1, max=30),
-        stop=tenacity.stop_after_attempt(10),
-        before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True),
-    )
-    def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
-        canonical_queue_name = message.queue_name
-        queue_name = canonical_queue_name
-        if delay:
-            queue_name = dq_name(queue_name)
-            message_eta = current_millis() + delay
-            message = message.copy(
-                queue_name=queue_name,
-                options={
-                    "eta": message_eta,
-                },
-            )
-
-        self.declare_queue(canonical_queue_name)
-        self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
-        self.emit_before("enqueue", message, delay)
-        query = {
-            "message_id": message.message_id,
-        }
-        defaults = {
+    def model_defaults(self, message: Message) -> dict[str, Any]:
+        rel_obj = message.options.get("rel_obj")
+        if rel_obj:
+            del message.options["rel_obj"]
+        return {
             "tenant": get_current_tenant(),
-            "queue_name": message.queue_name,
-            "actor_name": message.actor_name,
-            "state": TaskState.QUEUED,
-            "message": message.encode(),
-            "rel_obj": message.options.get("rel_obj", None),
+            "rel_obj": rel_obj,
+            **super().model_defaults(message),
         }
-        create_defaults = {
-            **query,
-            **defaults,
-        }
-        self.query_set.update_or_create(
-            **query,
-            defaults=defaults,
-            create_defaults=create_defaults,
-        )
-        self.emit_after("enqueue", message, delay)
-        return message
-
-    def get_declared_queues(self) -> set[str]:
-        return self.queues.copy()
-
-    def flush(self, queue_name: str):
-        self.query_set.filter(
-            queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
-        ).delete()
-
-    def flush_all(self):
-        for queue_name in self.queues:
-            self.flush(queue_name)
-
-    def join(
-        self,
-        queue_name: str,
-        interval: int = 100,
-        *,
-        timeout: int | None = None,
-    ):
-        deadline = timeout and time.monotonic() + timeout / 1000
-        while True:
-            if deadline and time.monotonic() >= deadline:
-                raise QueueJoinTimeout(queue_name)
-
-            if self.query_set.filter(
-                queue_name=queue_name,
-                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
-            ).exists():
-                return
-
-            time.sleep(interval / 1000)
-
-
-class _PostgresConsumer(Consumer):
-    def __init__(
-        self,
-        *args,
-        broker: Broker,
-        db_alias: str,
-        queue_name: str,
-        prefetch: int,
-        timeout: int,
-        **kwargs,
-    ):
-        self.logger = get_logger().bind()
-
-        self.notifies: list[Notify] = []
-        self.broker = broker
-        self.db_alias = db_alias
-        self.queue_name = queue_name
-        self.timeout = 30000 // 1000
-        self.unlock_queue = Queue()
-        self.in_processing = set()
-        self.prefetch = prefetch
-        self.misses = 0
-        self._listen_connection: DatabaseWrapper | None = None
-        self.postgres_channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
-
-        self.scheduler = Scheduler(self.broker)
-
-    @property
-    def connection(self) -> DatabaseWrapper:
-        return connections[self.db_alias]
-
-    @property
-    def query_set(self) -> QuerySet:
-        return Task.objects.using(self.db_alias)
-
-    @property
-    def listen_connection(self) -> DatabaseWrapper:
-        if self._listen_connection is not None and self._listen_connection.connection is not None:
-            return self._listen_connection
-        self._listen_connection = connections[self.db_alias]
-        # Required for notifications
-        # See https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-notifications
-        # Should be set to True by Django by default
-        self._listen_connection.set_autocommit(True)
-        with self._listen_connection.cursor() as cursor:
-            cursor.execute(sql.SQL("LISTEN {}").format(sql.Identifier(self.postgres_channel)))
-        return self._listen_connection
-
-    @raise_connection_error
-    def ack(self, message: Message):
-        self.unlock_queue.put_nowait(message)
-        self.query_set.filter(
-            message_id=message.message_id,
-            queue_name=message.queue_name,
-            state=TaskState.CONSUMED,
-        ).update(
-            state=TaskState.DONE,
-            message=message.encode(),
-        )
-        self.in_processing.remove(message.message_id)
-
-    @raise_connection_error
-    def nack(self, message: Message):
-        self.unlock_queue.put_nowait(message)
-        self.query_set.filter(
-            message_id=message.message_id,
-            queue_name=message.queue_name,
-        ).exclude(
-            state=TaskState.REJECTED,
-        ).update(
-            state=TaskState.REJECTED,
-            message=message.encode(),
-        )
-        self.in_processing.remove(message.message_id)
-
-    @raise_connection_error
-    def requeue(self, messages: Iterable[Message]):
-        for message in messages:
-            self.unlock_queue.put_nowait(message)
-        self.query_set.filter(
-            message_id__in=[message.message_id for message in messages],
-        ).update(
-            state=TaskState.QUEUED,
-        )
-        for message in messages:
-            self.in_processing.remove(message.message_id)
-        self._purge_locks()
-
-    def _fetch_pending_notifies(self) -> list[Notify]:
-        self.logger.debug(f"Polling for lost messages in {self.queue_name}")
-        notifies = (
-            self.query_set.filter(
-                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
-                queue_name=self.queue_name,
-            )
-            .exclude(
-                message_id__in=self.in_processing,
-            )
-            .values_list("message_id", flat=True)
-        )
-        return [Notify(pid=0, channel=self.postgres_channel, payload=item) for item in notifies]
-
-    def _poll_for_notify(self):
-        with self.listen_connection.cursor() as cursor:
-            self.logger.debug(f"timeout is {self.timeout}")
-            notifies = list(cursor.connection.notifies(timeout=self.timeout, stop_after=1))
-            self.logger.debug(
-                f"Received {len(notifies)} postgres notifies on channel {self.postgres_channel}"
-            )
-            self.notifies += notifies
-
-    def _get_message_lock_id(self, message: Message) -> int:
-        return _cast_lock_id(
-            f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
-        )
-
-    def _consume_one(self, message: Message) -> bool:
-        if message.message_id in self.in_processing:
-            self.logger.debug(f"Message {message.message_id} already consumed by self")
-            return False
-
-        result = (
-            self.query_set.filter(
-                message_id=message.message_id,
-                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
-            )
-            .extra(
-                where=["pg_try_advisory_lock(%s)"],
-                params=[self._get_message_lock_id(message)],
-            )
-            .update(
-                state=TaskState.CONSUMED,
-                mtime=timezone.now(),
-            )
-        )
-        return result == 1
-
-    @raise_connection_error
-    def __next__(self):
-        # This method is called every second
-
-        # If we don't have a connection yet, fetch missed notifications from the table directly
-        if self._listen_connection is None:
-            # We might miss a notification between the initial query and the first time we wait for
-            # notifications, it doesn't matter because we re-fetch for missed messages later on.
-            self.notifies = self._fetch_pending_notifies()
-            self.logger.debug(
-                f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
-            )
-
-        processing = len(self.in_processing)
-        if processing >= self.prefetch:
-            # Wait and don't consume the message, other worker will be faster
-            self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
-            self.logger.debug(
-                f"Too many messages in processing: {processing}. Sleeping {backoff_ms} ms"
-            )
-            time.sleep(backoff_ms / 1000)
-            return None
-
-        if not self.notifies:
-            self._poll_for_notify()
-
-        if not self.notifies and not randint(0, 300):  # nosec
-            # If there aren't any more notifies, randomly poll for missed/crashed messages.
-            # Since this method is called every second, this condition limits polling to
-            # on average one SELECT every five minutes of inactivity.
-            self.notifies[:] = self._fetch_pending_notifies()
-
-        # If we have some notifies, loop to find one to do
-        while self.notifies:
-            notify = self.notifies.pop(0)
-            task = self.query_set.get(message_id=notify.payload)
-            message = Message.decode(task.message)
-            if self._consume_one(message):
-                self.in_processing.add(message.message_id)
-                return MessageProxy(message)
-            else:
-                self.logger.debug(f"Message {message.message_id} already consumed. Skipping.")
-
-        # No message to process
-        self._purge_locks()
-        self._auto_purge()
-        self._run_scheduler()
-
-    def _purge_locks(self):
-        while True:
-            try:
-                message = self.unlock_queue.get(block=False)
-            except Empty:
-                return
-            self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
-            with self.connection.cursor() as cursor:
-                cursor.execute(
-                    "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
-                )
-            self.unlock_queue.task_done()
-
-    def _auto_purge(self):
-        # Automatically purge messages on average every 100k iteration.
-        # Dramatiq defaults to 1s, so this means one purge every 28 hours.
-        if randint(0, 100_000):  # nosec
-            return
-        self.logger.debug("Running garbage collector")
-        count = self.query_set.filter(
-            state__in=(TaskState.DONE, TaskState.REJECTED),
-            mtime__lte=timezone.now() - timezone.timedelta(days=30),
-            result_expiry__lte=timezone.now(),
-        ).delete()
-        self.logger.info(f"Purged {count} messages in all queues")
-
-    def _run_scheduler(self):
-        # Same as above, run on average once every minute
-        if randint(0, 60):  # nosec
-            return
-        self.logger.debug("Running scheduler")
-        self.scheduler.run()
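
Note on the surviving override: the generic columns that enqueue() used to fill in by hand (queue name, actor name, state, the encoded message body) are now presumably supplied by the base class's model_defaults(), and authentik's PostgresBroker only layers its own columns on top. The ordering is the interesting part: "rel_obj" is removed from message.options before super().model_defaults(message) runs, so a Django model instance never reaches the encoded message body and is persisted as its own column instead. Below is a minimal, self-contained sketch of that hook pattern; the stand-in classes and the assumed base-class behaviour are inferred from this diff, not taken from the package's documented API.

import json
from typing import Any


class FakeMessage:
    """Stand-in for dramatiq.message.Message, just enough for this sketch."""

    def __init__(self, queue_name: str, actor_name: str, options: dict[str, Any]):
        self.queue_name = queue_name
        self.actor_name = actor_name
        self.options = options

    def encode(self) -> bytes:
        # Real messages JSON-encode their payload (including options), which is
        # why a Django model instance must be popped out of options first.
        return json.dumps(self.options).encode()


class BrokerBase:
    """Stand-in for django_dramatiq_postgres.broker.PostgresBroker (assumed)."""

    def model_defaults(self, message: FakeMessage) -> dict[str, Any]:
        # Assumed: the base class supplies the generic Task columns.
        return {
            "queue_name": message.queue_name,
            "actor_name": message.actor_name,
            "message": message.encode(),
        }


class TenantBroker(BrokerBase):
    """Mirrors the shape of PostgresBroker.model_defaults() in the diff above."""

    def model_defaults(self, message: FakeMessage) -> dict[str, Any]:
        # Lift rel_obj out of the options *before* the base class encodes them...
        rel_obj = message.options.pop("rel_obj", None)
        return {
            "tenant": "public",  # get_current_tenant() in authentik
            "rel_obj": rel_obj,  # ...and persist it as its own column instead.
            **super().model_defaults(message),
        }


# rel_obj would typically arrive via dramatiq's send_with_options(), since
# unrecognized keyword arguments there end up in message.options.
row = TenantBroker().model_defaults(FakeMessage("default", "noop", {"rel_obj": object()}))
assert "rel_obj" not in row["message"].decode()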