|
|
|
|
@ -1,30 +1,29 @@
|
|
|
|
|
import logging
|
|
|
|
|
from psycopg.errors import AdminShutdown
|
|
|
|
|
import tenacity
|
|
|
|
|
import functools
|
|
|
|
|
from pglock.core import _cast_lock_id
|
|
|
|
|
from django.utils import timezone
|
|
|
|
|
from random import randint
|
|
|
|
|
import logging
|
|
|
|
|
import time
|
|
|
|
|
from django.db.backends.postgresql.base import DatabaseWrapper
|
|
|
|
|
from typing import Iterable
|
|
|
|
|
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
|
|
|
|
|
from queue import Queue, Empty
|
|
|
|
|
from collections.abc import Iterable
|
|
|
|
|
from queue import Empty, Queue
|
|
|
|
|
from random import randint
|
|
|
|
|
|
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
from dramatiq.broker import Broker, Consumer, MessageProxy
|
|
|
|
|
from dramatiq.message import Message
|
|
|
|
|
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
|
|
|
|
from dramatiq.results import Results
|
|
|
|
|
from dramatiq.errors import QueueJoinTimeout, ConnectionError
|
|
|
|
|
import orjson
|
|
|
|
|
import tenacity
|
|
|
|
|
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
|
|
|
|
|
from django.db.backends.postgresql.base import DatabaseWrapper
|
|
|
|
|
from django.db.models import QuerySet
|
|
|
|
|
from django.utils import timezone
|
|
|
|
|
from dramatiq.broker import Broker, Consumer, MessageProxy
|
|
|
|
|
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
|
|
|
|
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
|
|
|
|
from dramatiq.message import Message
|
|
|
|
|
from dramatiq.results import Results
|
|
|
|
|
from pglock.core import _cast_lock_id
|
|
|
|
|
from psycopg import Notify
|
|
|
|
|
from psycopg.errors import AdminShutdown
|
|
|
|
|
from structlog.stdlib import get_logger
|
|
|
|
|
|
|
|
|
|
from authentik.tasks.models import Queue
|
|
|
|
|
from authentik.tasks.models import Queue as MQueue
|
|
|
|
|
from authentik.tasks.results import PostgresBackend
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LOGGER = get_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -61,7 +60,7 @@ class PostgresBroker(Broker):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def query_set(self) -> QuerySet:
|
|
|
|
|
return Queue.objects.using(self.db_alias)
|
|
|
|
|
return MQueue.objects.using(self.db_alias)
|
|
|
|
|
|
|
|
|
|
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
|
|
|
|
|
self.declare_queue(queue_name)
|
|
|
|
|
@ -119,7 +118,7 @@ class PostgresBroker(Broker):
|
|
|
|
|
self.query_set.create(
|
|
|
|
|
message_id=message.message_id,
|
|
|
|
|
queue_name=message.queue_name,
|
|
|
|
|
state=Queue.State.QUEUED,
|
|
|
|
|
state=MQueue.State.QUEUED,
|
|
|
|
|
message=message.encode(),
|
|
|
|
|
)
|
|
|
|
|
self.emit_after("enqueue", message, delay)
|
|
|
|
|
@ -154,7 +153,7 @@ class PostgresBroker(Broker):
|
|
|
|
|
if (
|
|
|
|
|
self.query_set.filter(
|
|
|
|
|
queue_name=queue_name,
|
|
|
|
|
state__in=(Queue.State.QUEUED, Queue.State.CONSUMED),
|
|
|
|
|
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
|
|
|
|
)
|
|
|
|
|
== 0
|
|
|
|
|
):
|
|
|
|
|
@ -185,7 +184,7 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def query_set(self) -> QuerySet:
|
|
|
|
|
return Queue.objects.using(self.db_alias)
|
|
|
|
|
return MQueue.objects.using(self.db_alias)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def listen_connection(self) -> DatabaseWrapper:
|
|
|
|
|
@ -207,9 +206,9 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
self.query_set.filter(
|
|
|
|
|
message_id=message.message_id,
|
|
|
|
|
queue_name=message.queue_name,
|
|
|
|
|
state=Queue.State.CONSUMED,
|
|
|
|
|
state=MQueue.State.CONSUMED,
|
|
|
|
|
).update(
|
|
|
|
|
state=Queue.State.DONE,
|
|
|
|
|
state=MQueue.State.DONE,
|
|
|
|
|
message=message.encode(),
|
|
|
|
|
)
|
|
|
|
|
self.in_processing.remove(message.message_id)
|
|
|
|
|
@ -220,9 +219,9 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
self.query_set.filter(
|
|
|
|
|
message_id=message.message_id,
|
|
|
|
|
queue_name=message.queue_name,
|
|
|
|
|
state__ne=Queue.State.REJECTED,
|
|
|
|
|
state__ne=MQueue.State.REJECTED,
|
|
|
|
|
).update(
|
|
|
|
|
state=Queue.State.REJECT,
|
|
|
|
|
state=MQueue.State.REJECT,
|
|
|
|
|
message=message.encode(),
|
|
|
|
|
)
|
|
|
|
|
self.in_processing.remove(message.message_id)
|
|
|
|
|
@ -232,14 +231,14 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
self.query_set.filter(
|
|
|
|
|
message_id__in=[message.message_id for message in messages],
|
|
|
|
|
).update(
|
|
|
|
|
state=Queue.State.QUEUED,
|
|
|
|
|
state=MQueue.State.QUEUED,
|
|
|
|
|
)
|
|
|
|
|
# We don't care about locks, requeue occurs on worker stop
|
|
|
|
|
|
|
|
|
|
def _fetch_pending_notifies(self) -> list[Notify]:
|
|
|
|
|
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
|
|
|
|
|
notifies = self.query_set.filter(
|
|
|
|
|
state__in=(Queue.State.QUEUED, Queue.State.CONSUMED), queue_name=self.queue_name
|
|
|
|
|
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
|
|
|
|
)
|
|
|
|
|
channel = channel_name(self.connection, self.queue_name, "enqueue")
|
|
|
|
|
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
|
|
|
|
@ -252,7 +251,7 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
|
|
|
|
|
def _get_message_lock_id(self, message: Message) -> int:
|
|
|
|
|
return _cast_lock_id(
|
|
|
|
|
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}"
|
|
|
|
|
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _consume_one(self, message: Message) -> bool:
|
|
|
|
|
@ -262,9 +261,10 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
|
|
|
|
|
result = (
|
|
|
|
|
self.query_set.filter(
|
|
|
|
|
message_id=message.message_id, state__in=(Queue.State.QUEUED, Queue.State.CONSUMED)
|
|
|
|
|
message_id=message.message_id,
|
|
|
|
|
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
|
|
|
|
)
|
|
|
|
|
.update(state=Queue.State.CONSUMED, mtime=timezone.now())
|
|
|
|
|
.update(state=MQueue.State.CONSUMED, mtime=timezone.now())
|
|
|
|
|
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
|
|
|
|
|
)
|
|
|
|
|
return result == 1
|
|
|
|
|
@ -276,7 +276,7 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
# If we don't have a connection yet, fetch missed notifications from the table directly
|
|
|
|
|
if self._listen_connection is None:
|
|
|
|
|
# We might miss a notification between the initial query and the first time we wait for
|
|
|
|
|
# notitications, it doesn't matter because we re-fetch for missed messages later on.
|
|
|
|
|
# notifications, it doesn't matter because we re-fetch for missed messages later on.
|
|
|
|
|
self.notifies = self._fetch_pending_notifies()
|
|
|
|
|
self.logger.debug(
|
|
|
|
|
f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
|
|
|
|
|
@ -340,7 +340,7 @@ class _PostgresConsumer(Consumer):
|
|
|
|
|
return
|
|
|
|
|
self.logger.debug("Running garbage collector")
|
|
|
|
|
count = self.query_set.filter(
|
|
|
|
|
state__in=(Queue.State.DONE, Queue.State.REJECTED),
|
|
|
|
|
state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
|
|
|
|
|
mtime__lte=timezone.now() - timezone.timedelta(days=30),
|
|
|
|
|
).delete()
|
|
|
|
|
self.logger.info(f"Purged {count} messages in all queues")
|
|
|
|
|
|