diff --git a/authentik/root/settings.py b/authentik/root/settings.py
index e431b51559..91ff629d3b 100644
--- a/authentik/root/settings.py
+++ b/authentik/root/settings.py
@@ -122,7 +122,7 @@ TENANT_APPS = [
     "authentik.stages.user_login",
     "authentik.stages.user_logout",
     "authentik.stages.user_write",
-    # "authentik.tasks",
+    "authentik.tasks",
     "authentik.brands",
     "authentik.blueprints",
     "guardian",
diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py
index a744ca9596..c807347bd8 100644
--- a/authentik/tasks/broker.py
+++ b/authentik/tasks/broker.py
@@ -1,30 +1,29 @@
-import logging
-from psycopg.errors import AdminShutdown
-import tenacity
 import functools
-from pglock.core import _cast_lock_id
-from django.utils import timezone
-from random import randint
+import logging
 import time
-from django.db.backends.postgresql.base import DatabaseWrapper
-from typing import Iterable
-from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
-from queue import Queue, Empty
+from collections.abc import Iterable
+from queue import Empty, Queue
+from random import randint
 
-from django.db.models import QuerySet
-from dramatiq.broker import Broker, Consumer, MessageProxy
-from dramatiq.message import Message
-from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
-from dramatiq.results import Results
-from dramatiq.errors import QueueJoinTimeout, ConnectionError
 import orjson
+import tenacity
+from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
+from django.db.backends.postgresql.base import DatabaseWrapper
+from django.db.models import QuerySet
+from django.utils import timezone
+from dramatiq.broker import Broker, Consumer, MessageProxy
+from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
+from dramatiq.errors import ConnectionError, QueueJoinTimeout
+from dramatiq.message import Message
+from dramatiq.results import Results
+from pglock.core import _cast_lock_id
 from psycopg import Notify
+from psycopg.errors import AdminShutdown
 from structlog.stdlib import get_logger
 
-from authentik.tasks.models import Queue
+from authentik.tasks.models import Queue as MQueue
 from authentik.tasks.results import PostgresBackend
 
-
 LOGGER = get_logger()
@@ -61,7 +60,7 @@ class PostgresBroker(Broker):
 
     @property
     def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)
 
     def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
         self.declare_queue(queue_name)
@@ -119,7 +118,7 @@ class PostgresBroker(Broker):
         self.query_set.create(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
             message=message.encode(),
         )
         self.emit_after("enqueue", message, delay)
@@ -154,7 +153,7 @@ class PostgresBroker(Broker):
             if (
                 self.query_set.filter(
                     queue_name=queue_name,
-                    state__in=(Queue.State.QUEUED, Queue.State.CONSUMED),
+                    state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
                 )
                 == 0
             ):
@@ -185,7 +184,7 @@ class _PostgresConsumer(Consumer):
 
     @property
    def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)
 
     @property
     def listen_connection(self) -> DatabaseWrapper:
@@ -207,9 +206,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.CONSUMED,
+            state=MQueue.State.CONSUMED,
         ).update(
-            state=Queue.State.DONE,
+            state=MQueue.State.DONE,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -220,9 +219,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state__ne=Queue.State.REJECTED,
+            state__ne=MQueue.State.REJECTED,
         ).update(
-            state=Queue.State.REJECT,
+            state=MQueue.State.REJECT,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -232,14 +231,14 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id__in=[message.message_id for message in messages],
         ).update(
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
         )
         # We don't care about locks, requeue occurs on worker stop
 
     def _fetch_pending_notifies(self) -> list[Notify]:
         self.logger.debug(f"Polling for lost messages in {self.queue_name}")
         notifies = self.query_set.filter(
-            state__in=(Queue.State.QUEUED, Queue.State.CONSUMED), queue_name=self.queue_name
+            state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
         )
         channel = channel_name(self.connection, self.queue_name, "enqueue")
         return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
@@ -252,7 +251,7 @@ class _PostgresConsumer(Consumer):
 
     def _get_message_lock_id(self, message: Message) -> int:
         return _cast_lock_id(
-            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"
+            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"  # noqa: E501
         )
 
     def _consume_one(self, message: Message) -> bool:
@@ -262,9 +261,10 @@ class _PostgresConsumer(Consumer):
 
         result = (
             self.query_set.filter(
-                message_id=message.message_id, state__in=(Queue.State.QUEUED, Queue.State.CONSUMED)
+                message_id=message.message_id,
+                state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
             )
-            .update(state=Queue.State.CONSUMED, mtime=timezone.now())
+            .update(state=MQueue.State.CONSUMED, mtime=timezone.now())
             .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
         )
         return result == 1
@@ -276,7 +276,7 @@ class _PostgresConsumer(Consumer):
         # If we don't have a connection yet, fetch missed notifications from the table directly
         if self._listen_connection is None:
            # We might miss a notification between the initial query and the first time we wait for
-            # notitications, it doesn't matter because we re-fetch for missed messages later on.
+            # notifications, it doesn't matter because we re-fetch for missed messages later on.
             self.notifies = self._fetch_pending_notifies()
             self.logger.debug(
                 f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
             )
@@ -340,7 +340,7 @@ class _PostgresConsumer(Consumer):
             return
         self.logger.debug("Running garbage collector")
         count = self.query_set.filter(
-            state__in=(Queue.State.DONE, Queue.State.REJECTED),
+            state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
             mtime__lte=timezone.now() - timezone.timedelta(days=30),
         ).delete()
         self.logger.info(f"Purged {count} messages in all queues")
diff --git a/authentik/tasks/migrations/0001_initial.py b/authentik/tasks/migrations/0001_initial.py
index b60085694e..c182cf0c2d 100644
--- a/authentik/tasks/migrations/0001_initial.py
+++ b/authentik/tasks/migrations/0001_initial.py
@@ -1,4 +1,4 @@
-# Generated by Django 5.0.12 on 2025-03-08 15:35
+# Generated by Django 5.0.12 on 2025-03-08 22:39
 
 import django.utils.timezone
 import uuid
@@ -37,7 +37,7 @@ class Migration(migrations.Migration):
                 ("mtime", models.DateTimeField(default=django.utils.timezone.now)),
                 ("message", models.JSONField(blank=True, null=True)),
                 ("result", models.JSONField(blank=True, null=True)),
-                ("result_ttl", models.DateTimeField()),
+                ("result_ttl", models.DateTimeField(blank=True, null=True)),
             ],
             options={
                 "indexes": [
@@ -45,5 +45,7 @@ class Migration(migrations.Migration):
                 ],
             },
         ),
-        migrations.RunSQL("ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop),
+        migrations.RunSQL(
+            "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
+        ),
     ]
diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py
index bd7951616a..59c93a2f4b 100644
--- a/authentik/tasks/models.py
+++ b/authentik/tasks/models.py
@@ -1,4 +1,5 @@
 from uuid import uuid4
+
 from django.db import models
 from django.utils import timezone
 
@@ -19,7 +20,7 @@ class Queue(models.Model):
     result_ttl = models.DateTimeField(blank=True, null=True)
 
     class Meta:
-        indexes = (
-            models.Index(fields=("state", "mtime")),
-            models.Index(fields=("mesage__actor_name",)),
-        )
+        indexes = (models.Index(fields=("state", "mtime")),)
+
+    def __str__(self):
+        return str(self.message_id)
diff --git a/authentik/tasks/results.py b/authentik/tasks/results.py
index be675a399a..8e014fd396 100644
--- a/authentik/tasks/results.py
+++ b/authentik/tasks/results.py
@@ -2,7 +2,7 @@ from django.db import DEFAULT_DB_ALIAS
 from django.db.models import QuerySet
 from django.utils import timezone
 from dramatiq.message import Message, get_encoder
-from dramatiq.results.backend import MResult, Missing, Result, ResultBackend
+from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
 
 from authentik.tasks.models import Queue
 
diff --git a/pyproject.toml b/pyproject.toml
index 861867e1dd..ca736cc8ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ skip = [
     "./gen-go-api",
     "*.api.mdx",
     "./htmlcov",
+    "./web/custom-elements.json",
 ]
 dictionary = ".github/codespell-dictionary.txt,-"
 ignore-words = ".github/codespell-words.txt"