commit de54404ab7
parent f8c3b64274
Author: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Date:   2025-03-08 23:43:10 +01:00
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

6 changed files with 46 additions and 42 deletions


@@ -122,7 +122,7 @@ TENANT_APPS = [
     "authentik.stages.user_login",
     "authentik.stages.user_logout",
     "authentik.stages.user_write",
-    # "authentik.tasks",
+    "authentik.tasks",
     "authentik.brands",
     "authentik.blueprints",
     "guardian",


@@ -1,30 +1,29 @@
-import logging
-from psycopg.errors import AdminShutdown
-import tenacity
 import functools
-from pglock.core import _cast_lock_id
-from django.utils import timezone
-from random import randint
+import logging
 import time
-from django.db.backends.postgresql.base import DatabaseWrapper
-from typing import Iterable
-from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
-from queue import Queue, Empty
-from django.db.models import QuerySet
-from dramatiq.broker import Broker, Consumer, MessageProxy
-from dramatiq.message import Message
-from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
-from dramatiq.results import Results
-from dramatiq.errors import QueueJoinTimeout, ConnectionError
+from collections.abc import Iterable
+from queue import Empty, Queue
+from random import randint
 import orjson
+import tenacity
+from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
+from django.db.backends.postgresql.base import DatabaseWrapper
+from django.db.models import QuerySet
+from django.utils import timezone
+from dramatiq.broker import Broker, Consumer, MessageProxy
+from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
+from dramatiq.errors import ConnectionError, QueueJoinTimeout
+from dramatiq.message import Message
+from dramatiq.results import Results
+from pglock.core import _cast_lock_id
 from psycopg import Notify
+from psycopg.errors import AdminShutdown
 from structlog.stdlib import get_logger
-from authentik.tasks.models import Queue
+from authentik.tasks.models import Queue as MQueue
 from authentik.tasks.results import PostgresBackend
 
 LOGGER = get_logger()
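
Note: the model import gains the MQueue alias because the old module imported
both "from queue import Queue, Empty" (stdlib, used for the consumer's
in-memory buffer) and "from authentik.tasks.models import Queue", so the model
silently shadowed the stdlib class. A minimal sketch of the two roles
(illustrative, not part of the commit):

    from queue import Queue

    from authentik.tasks.models import Queue as MQueue

    buffer: Queue = Queue()  # stdlib: thread-safe in-process notification buffer
    pending = MQueue.objects.filter(state=MQueue.State.QUEUED)  # rows in the DB table
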
@@ -61,7 +60,7 @@ class PostgresBroker(Broker):
     @property
     def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)
 
     def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
         self.declare_queue(queue_name)
@@ -119,7 +118,7 @@ class PostgresBroker(Broker):
         self.query_set.create(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
             message=message.encode(),
         )
         self.emit_after("enqueue", message, delay)
@@ -154,7 +153,7 @@ class PostgresBroker(Broker):
         if (
             self.query_set.filter(
                 queue_name=queue_name,
-                state__in=(Queue.State.QUEUED, Queue.State.CONSUMED),
+                state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
             )
             == 0
         ):
@@ -185,7 +184,7 @@ class _PostgresConsumer(Consumer):
     @property
     def query_set(self) -> QuerySet:
-        return Queue.objects.using(self.db_alias)
+        return MQueue.objects.using(self.db_alias)
 
     @property
     def listen_connection(self) -> DatabaseWrapper:
@@ -207,9 +206,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state=Queue.State.CONSUMED,
+            state=MQueue.State.CONSUMED,
         ).update(
-            state=Queue.State.DONE,
+            state=MQueue.State.DONE,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -220,9 +219,9 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id=message.message_id,
             queue_name=message.queue_name,
-            state__ne=Queue.State.REJECTED,
+            state__ne=MQueue.State.REJECTED,
         ).update(
-            state=Queue.State.REJECT,
+            state=MQueue.State.REJECT,
             message=message.encode(),
         )
         self.in_processing.remove(message.message_id)
@@ -232,14 +231,14 @@ class _PostgresConsumer(Consumer):
         self.query_set.filter(
             message_id__in=[message.message_id for message in messages],
         ).update(
-            state=Queue.State.QUEUED,
+            state=MQueue.State.QUEUED,
         )
         # We don't care about locks, requeue occurs on worker stop
 
     def _fetch_pending_notifies(self) -> list[Notify]:
         self.logger.debug(f"Polling for lost messages in {self.queue_name}")
         notifies = self.query_set.filter(
-            state__in=(Queue.State.QUEUED, Queue.State.CONSUMED), queue_name=self.queue_name
+            state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
         )
         channel = channel_name(self.connection, self.queue_name, "enqueue")
         return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
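
Note: the "we don't care about locks" comment relies on PostgreSQL advisory
locks being session-scoped: when a worker stops and its connection closes,
every advisory lock it held is released automatically, so requeued rows become
claimable again. _fetch_pending_notifies then rebuilds synthetic Notify
objects straight from the table. A rough sketch of that recovery query
(psycopg 3; the literal state values are assumptions):

    import psycopg

    with psycopg.connect("dbname=authentik") as conn:
        # Anything still queued/consumed may have been lost while we were
        # not listening; re-deliver it. Locks from dead sessions are gone.
        rows = conn.execute(
            "SELECT message_id, message FROM authentik_tasks_queue"
            " WHERE state IN ('queued', 'consumed') AND queue_name = %s",
            ("default",),
        ).fetchall()
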
@@ -252,7 +251,7 @@ class _PostgresConsumer(Consumer):
     def _get_message_lock_id(self, message: Message) -> int:
         return _cast_lock_id(
-            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"
+            f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"  # noqa: E501
         )
 
     def _consume_one(self, message: Message) -> bool:
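
Note: PostgreSQL advisory locks are keyed by a 64-bit integer, so pglock's
_cast_lock_id has to reduce the channel-qualified message id string to one.
One plausible implementation (an assumption, not pglock's actual code):

    import hashlib

    def cast_lock_id(key: str) -> int:
        # Hash the string, then take 8 bytes as a signed 64-bit integer,
        # the type pg_try_advisory_lock(bigint) expects.
        return int.from_bytes(hashlib.sha256(key.encode()).digest()[:8], "big", signed=True)
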
@@ -262,9 +261,9 @@ class _PostgresConsumer(Consumer):
         result = (
             self.query_set.filter(
-                message_id=message.message_id, state__in=(Queue.State.QUEUED, Queue.State.CONSUMED)
+                message_id=message.message_id,
+                state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
             )
-            .update(state=Queue.State.CONSUMED, mtime=timezone.now())
+            .update(state=MQueue.State.CONSUMED, mtime=timezone.now())
             .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
         )
         return result == 1
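
Note: this update is the claim step: flip the row to CONSUMED only if it is
still claimable and this session can take the message's advisory lock, so at
most one worker wins each message. The equivalent raw pattern, sketched with
psycopg 3 (state values assumed):

    import psycopg

    CLAIM_SQL = """
        UPDATE authentik_tasks_queue
           SET state = 'consumed', mtime = now()
         WHERE message_id = %s
           AND state IN ('queued', 'consumed')
           AND pg_try_advisory_lock(%s)
    """

    def try_claim(conn: psycopg.Connection, message_id: str, lock_id: int) -> bool:
        # Exactly one affected row means this worker won the message.
        return conn.execute(CLAIM_SQL, (message_id, lock_id)).rowcount == 1
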
@@ -276,7 +276,7 @@ class _PostgresConsumer(Consumer):
         # If we don't have a connection yet, fetch missed notifications from the table directly
         if self._listen_connection is None:
             # We might miss a notification between the initial query and the first time we wait for
-            # notitications, it doesn't matter because we re-fetch for missed messages later on.
+            # notifications, it doesn't matter because we re-fetch for missed messages later on.
             self.notifies = self._fetch_pending_notifies()
             self.logger.debug(
                 f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
@@ -340,7 +340,7 @@ class _PostgresConsumer(Consumer):
             return
         self.logger.debug("Running garbage collector")
         count = self.query_set.filter(
-            state__in=(Queue.State.DONE, Queue.State.REJECTED),
+            state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
             mtime__lte=timezone.now() - timezone.timedelta(days=30),
         ).delete()
         self.logger.info(f"Purged {count} messages in all queues")
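
Note: QuerySet.delete() returns a (total, per_model) tuple rather than a bare
int, so the log line above prints something like
"Purged (3, {'authentik_tasks.Queue': 3}) messages in all queues". A sketch
that logs only the count (illustrative):

    from django.utils import timezone

    def purge(query_set, states) -> int:
        total, _per_model = query_set.filter(
            state__in=states,
            mtime__lte=timezone.now() - timezone.timedelta(days=30),
        ).delete()
        return total
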


@@ -1,4 +1,4 @@
-# Generated by Django 5.0.12 on 2025-03-08 15:35
+# Generated by Django 5.0.12 on 2025-03-08 22:39
 
 import django.utils.timezone
 import uuid
@@ -37,7 +37,7 @@ class Migration(migrations.Migration):
                 ("mtime", models.DateTimeField(default=django.utils.timezone.now)),
                 ("message", models.JSONField(blank=True, null=True)),
                 ("result", models.JSONField(blank=True, null=True)),
-                ("result_ttl", models.DateTimeField()),
+                ("result_ttl", models.DateTimeField(blank=True, null=True)),
             ],
             options={
                 "indexes": [
@@ -45,5 +45,7 @@ class Migration(migrations.Migration):
                 ],
             },
         ),
-        migrations.RunSQL("ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop),
+        migrations.RunSQL(
+            "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
+        ),
     ]


@@ -1,4 +1,5 @@
 from uuid import uuid4
+
 from django.db import models
 from django.utils import timezone
@@ -19,7 +20,7 @@ class Queue(models.Model):
     result_ttl = models.DateTimeField(blank=True, null=True)
 
     class Meta:
-        indexes = (
-            models.Index(fields=("state", "mtime")),
-            models.Index(fields=("mesage__actor_name",)),
-        )
+        indexes = (models.Index(fields=("state", "mtime")),)
+
+    def __str__(self):
+        return str(self.message_id)
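
Note: the dropped index was doubly broken: "mesage__actor_name" typos the
field name, and models.Index(fields=...) only accepts concrete model fields,
not JSON key paths, so Django's system checks would reject it either way. If
lookups inside the JSON message ever need an index, a PostgreSQL GIN index
over the field is the usual route (hypothetical sketch, not in this commit):

    from django.contrib.postgres.indexes import GinIndex
    from django.db import models

    class Queue(models.Model):
        state = models.TextField()
        mtime = models.DateTimeField()
        message = models.JSONField(blank=True, null=True)

        class Meta:
            indexes = (
                models.Index(fields=("state", "mtime")),
                GinIndex(fields=("message",)),  # JSON containment lookups on message
            )
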


@@ -2,7 +2,7 @@ from django.db import DEFAULT_DB_ALIAS
 from django.db.models import QuerySet
 from django.utils import timezone
 from dramatiq.message import Message, get_encoder
-from dramatiq.results.backend import MResult, Missing, Result, ResultBackend
+from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
 
 from authentik.tasks.models import Queue


@@ -26,6 +26,7 @@ skip = [
     "./gen-go-api",
     "*.api.mdx",
     "./htmlcov",
+    "./web/custom-elements.json",
 ]
 dictionary = ".github/codespell-dictionary.txt,-"
 ignore-words = ".github/codespell-words.txt"