Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-08 23:43:10 +01:00
parent f8c3b64274
commit de54404ab7
6 changed files with 46 additions and 42 deletions

View File

@ -122,7 +122,7 @@ TENANT_APPS = [
"authentik.stages.user_login",
"authentik.stages.user_logout",
"authentik.stages.user_write",
# "authentik.tasks",
"authentik.tasks",
"authentik.brands",
"authentik.blueprints",
"guardian",

View File

@ -1,30 +1,29 @@
import logging
from psycopg.errors import AdminShutdown
import tenacity
import functools
from pglock.core import _cast_lock_id
from django.utils import timezone
from random import randint
import logging
import time
from django.db.backends.postgresql.base import DatabaseWrapper
from typing import Iterable
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
from queue import Queue, Empty
from collections.abc import Iterable
from queue import Empty, Queue
from random import randint
from django.db.models import QuerySet
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.message import Message
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.results import Results
from dramatiq.errors import QueueJoinTimeout, ConnectionError
import orjson
import tenacity
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
from django.db.backends.postgresql.base import DatabaseWrapper
from django.db.models import QuerySet
from django.utils import timezone
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.message import Message
from dramatiq.results import Results
from pglock.core import _cast_lock_id
from psycopg import Notify
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from authentik.tasks.models import Queue
from authentik.tasks.models import Queue as MQueue
from authentik.tasks.results import PostgresBackend
LOGGER = get_logger()
@ -61,7 +60,7 @@ class PostgresBroker(Broker):
@property
def query_set(self) -> QuerySet:
return Queue.objects.using(self.db_alias)
return MQueue.objects.using(self.db_alias)
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
self.declare_queue(queue_name)
@ -119,7 +118,7 @@ class PostgresBroker(Broker):
self.query_set.create(
message_id=message.message_id,
queue_name=message.queue_name,
state=Queue.State.QUEUED,
state=MQueue.State.QUEUED,
message=message.encode(),
)
self.emit_after("enqueue", message, delay)
@ -154,7 +153,7 @@ class PostgresBroker(Broker):
if (
self.query_set.filter(
queue_name=queue_name,
state__in=(Queue.State.QUEUED, Queue.State.CONSUMED),
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
)
== 0
):
@ -185,7 +184,7 @@ class _PostgresConsumer(Consumer):
@property
def query_set(self) -> QuerySet:
return Queue.objects.using(self.db_alias)
return MQueue.objects.using(self.db_alias)
@property
def listen_connection(self) -> DatabaseWrapper:
@ -207,9 +206,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state=Queue.State.CONSUMED,
state=MQueue.State.CONSUMED,
).update(
state=Queue.State.DONE,
state=MQueue.State.DONE,
message=message.encode(),
)
self.in_processing.remove(message.message_id)
@ -220,9 +219,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state__ne=Queue.State.REJECTED,
state__ne=MQueue.State.REJECTED,
).update(
state=Queue.State.REJECT,
state=MQueue.State.REJECT,
message=message.encode(),
)
self.in_processing.remove(message.message_id)
@ -232,14 +231,14 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id__in=[message.message_id for message in messages],
).update(
state=Queue.State.QUEUED,
state=MQueue.State.QUEUED,
)
# We don't care about locks, requeue occurs on worker stop
def _fetch_pending_notifies(self) -> list[Notify]:
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
notifies = self.query_set.filter(
state__in=(Queue.State.QUEUED, Queue.State.CONSUMED), queue_name=self.queue_name
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
)
channel = channel_name(self.connection, self.queue_name, "enqueue")
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
@ -252,7 +251,7 @@ class _PostgresConsumer(Consumer):
def _get_message_lock_id(self, message: Message) -> int:
return _cast_lock_id(
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501
)
def _consume_one(self, message: Message) -> bool:
@ -262,9 +261,10 @@ class _PostgresConsumer(Consumer):
result = (
self.query_set.filter(
message_id=message.message_id, state__in=(Queue.State.QUEUED, Queue.State.CONSUMED)
message_id=message.message_id,
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
)
.update(state=Queue.State.CONSUMED, mtime=timezone.now())
.update(state=MQueue.State.CONSUMED, mtime=timezone.now())
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
)
return result == 1
@ -276,7 +276,7 @@ class _PostgresConsumer(Consumer):
# If we don't have a connection yet, fetch missed notifications from the table directly
if self._listen_connection is None:
# We might miss a notification between the initial query and the first time we wait for
# notitications, it doesn't matter because we re-fetch for missed messages later on.
# notifications, it doesn't matter because we re-fetch for missed messages later on.
self.notifies = self._fetch_pending_notifies()
self.logger.debug(
f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
@ -340,7 +340,7 @@ class _PostgresConsumer(Consumer):
return
self.logger.debug("Running garbage collector")
count = self.query_set.filter(
state__in=(Queue.State.DONE, Queue.State.REJECTED),
state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
mtime__lte=timezone.now() - timezone.timedelta(days=30),
).delete()
self.logger.info(f"Purged {count} messages in all queues")

View File

@ -1,4 +1,4 @@
# Generated by Django 5.0.12 on 2025-03-08 15:35
# Generated by Django 5.0.12 on 2025-03-08 22:39
import django.utils.timezone
import uuid
@ -37,7 +37,7 @@ class Migration(migrations.Migration):
("mtime", models.DateTimeField(default=django.utils.timezone.now)),
("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField()),
("result_ttl", models.DateTimeField(blank=True, null=True)),
],
options={
"indexes": [
@ -45,5 +45,7 @@ class Migration(migrations.Migration):
],
},
),
migrations.RunSQL("ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop),
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
),
]

View File

@ -1,4 +1,5 @@
from uuid import uuid4
from django.db import models
from django.utils import timezone
@ -19,7 +20,7 @@ class Queue(models.Model):
result_ttl = models.DateTimeField(blank=True, null=True)
class Meta:
indexes = (
models.Index(fields=("state", "mtime")),
models.Index(fields=("mesage__actor_name",)),
)
indexes = (models.Index(fields=("state", "mtime")),)
def __str__(self):
return str(self.message_id)

View File

@ -2,7 +2,7 @@ from django.db import DEFAULT_DB_ALIAS
from django.db.models import QuerySet
from django.utils import timezone
from dramatiq.message import Message, get_encoder
from dramatiq.results.backend import MResult, Missing, Result, ResultBackend
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
from authentik.tasks.models import Queue

View File

@ -26,6 +26,7 @@ skip = [
"./gen-go-api",
"*.api.mdx",
"./htmlcov",
"./web/custom-elements.json",
]
dictionary = ".github/codespell-dictionary.txt,-"
ignore-words = ".github/codespell-words.txt"