Start moving stuff to package. Check previous commit for changes to forward-port

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-12 18:37:50 +02:00
parent 1a6ab7f24b
commit e0dcade9ad
6 changed files with 528 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from django_dramatiq_postgres.conf import Conf
class DjangoDramatiqPostgres(AppConfig):
name = "django_dramatiq_postgres"
verbose_name = "Django DramatiQ postgres"
verbose_name = "Django Dramatiq postgres"
def ready(self):
old_broker = dramatiq.get_broker()
@ -28,12 +28,13 @@ class DjangoDramatiqPostgres(AppConfig):
"middleware": [],
}
broker: dramatiq.broker.Broker = import_string(Conf.broker_class)(
*broker_args, **broker_kwargs
*broker_args,
**broker_kwargs,
)
for middleware_class, middleware_kwargs in Conf.middlewares.items():
middleware: dramatiq.middleware.middleware.Middleware = import_string(middleware_class)(
**middleware_kwargs
**middleware_kwargs,
)
broker.add_middleware(middleware)

View File

@ -0,0 +1,425 @@
import functools
import logging
import time
from collections.abc import Iterable
from queue import Empty, Queue
from random import randint
from typing import Any
from django.utils.functional import cached_property
from django.utils.module_loading import import_string
import tenacity
from django.db import (
DEFAULT_DB_ALIAS,
DatabaseError,
InterfaceError,
OperationalError,
connections,
)
from django.db.backends.postgresql.base import DatabaseWrapper
from django.db.models import QuerySet
from django.utils import timezone
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.message import Message
from dramatiq.middleware import (
AgeLimit,
Callbacks,
Middleware,
Pipelines,
Prometheus,
Retries,
ShutdownNotifications,
TimeLimit,
)
from pglock.core import _cast_lock_id
from psycopg import Notify, sql
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.middleware import DbConnectionMiddleware
from django_dramatiq_postgres.models import Task, ChannelIdentifier, TaskState, CHANNEL_PREFIX
# Module-level logger handed to tenacity's before_sleep_log in PostgresBroker.enqueue.
LOGGER = get_logger()
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
    """Build the postgres channel name for a queue and a channel identifier."""
    return ".".join((CHANNEL_PREFIX, queue_name, identifier.value))
def raise_connection_error(func):
    """Decorator translating Django's OperationalError into dramatiq's ConnectionError.

    Dramatiq workers treat ConnectionError as a retryable broker failure, so
    database connectivity problems surface through the expected channel.
    """

    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except OperationalError as exc:
            raise ConnectionError(str(exc)) from exc

    return inner
class PostgresBroker(Broker):
    """Dramatiq broker that persists messages as rows in a PostgreSQL table.

    Messages are stored through the model named by ``Conf.task_class``;
    consumers (see ``_PostgresConsumer``) are woken up via postgres
    LISTEN/NOTIFY, emitted by a trigger on the task table.
    """

    def __init__(
        self,
        *args,
        middleware: list[Middleware] | None = None,
        db_alias: str = DEFAULT_DB_ALIAS,
        **kwargs,
    ):
        # Pass an empty middleware list to the base class; the stack is
        # assembled below so the default set can be replaced wholesale.
        super().__init__(*args, middleware=[], **kwargs)
        self.logger = get_logger().bind()
        self.queues = set()
        self.db_alias = db_alias

        self.middleware = []
        if middleware is None:
            # Default stack: dramatiq's usual middlewares plus DB connection
            # lifecycle handling.
            for m in (
                DbConnectionMiddleware,
                Prometheus,
                AgeLimit,
                TimeLimit,
                ShutdownNotifications,
                Callbacks,
                Pipelines,
                Retries,
            ):
                self.add_middleware(m())
        for m in middleware or []:
            self.add_middleware(m)

    @property
    def connection(self) -> DatabaseWrapper:
        """Django database connection for this broker's alias."""
        return connections[self.db_alias]

    @property
    def consumer_class(self) -> "type[_PostgresConsumer]":
        return _PostgresConsumer

    @cached_property
    def model(self) -> type[Task]:
        """Concrete task model, resolved from settings (``Conf.task_class``)."""
        return import_string(Conf.task_class)

    @property
    def query_set(self) -> QuerySet:
        """Base queryset over the task table, bound to this broker's DB alias."""
        return self.model.objects.using(self.db_alias)

    def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
        """Create a consumer for ``queue_name``; ``timeout`` is in milliseconds."""
        self.declare_queue(queue_name)
        return self.consumer_class(
            broker=self,
            db_alias=self.db_alias,
            queue_name=queue_name,
            prefetch=prefetch,
            timeout=timeout,
        )

    def declare_queue(self, queue_name: str):
        """Register ``queue_name`` (and its delay queue) with this broker."""
        if queue_name not in self.queues:
            self.emit_before("declare_queue", queue_name)
            self.queues.add(queue_name)
            # Nothing to do, all queues are in the same table
            self.emit_after("declare_queue", queue_name)

            delayed_name = dq_name(queue_name)
            self.delay_queues.add(delayed_name)
            self.emit_after("declare_delay_queue", delayed_name)

    def model_defaults(self, message: Message) -> dict[str, Any]:
        """Field values stored for ``message`` when it is enqueued."""
        return {
            "queue_name": message.queue_name,
            "actor_name": message.actor_name,
            "state": TaskState.QUEUED,
            "message": message.encode(),
        }

    @tenacity.retry(
        retry=tenacity.retry_if_exception_type(
            (
                AdminShutdown,
                InterfaceError,
                DatabaseError,
                ConnectionError,
                OperationalError,
            )
        ),
        reraise=True,
        wait=tenacity.wait_random_exponential(multiplier=1, max=30),
        stop=tenacity.stop_after_attempt(10),
        before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True),
    )
    def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
        """Persist ``message``; with ``delay`` (ms) it goes to the delay queue.

        Retried with randomized exponential backoff on transient DB errors.
        """
        canonical_queue_name = message.queue_name
        queue_name = canonical_queue_name
        if delay:
            queue_name = dq_name(queue_name)
            message_eta = current_millis() + delay
            message = message.copy(
                queue_name=queue_name,
                options={
                    "eta": message_eta,
                },
            )
        self.declare_queue(canonical_queue_name)
        self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")

        self.emit_before("enqueue", message, delay)
        query = {
            "message_id": message.message_id,
        }
        defaults = self.model_defaults(message)
        create_defaults = {
            **query,
            **defaults,
        }
        # Upsert keyed on the message id so re-enqueueing the same message
        # updates its row instead of creating a duplicate.
        self.query_set.update_or_create(
            **query,
            defaults=defaults,
            create_defaults=create_defaults,
        )
        self.emit_after("enqueue", message, delay)
        return message

    def get_declared_queues(self) -> set[str]:
        return self.queues.copy()

    def flush(self, queue_name: str):
        """Delete all messages of ``queue_name``, including its delay and dead-letter queues."""
        self.query_set.filter(
            queue_name__in=(queue_name, dq_name(queue_name), xq_name(queue_name))
        ).delete()

    def flush_all(self):
        """Delete all messages of every declared queue."""
        for queue_name in self.queues:
            self.flush(queue_name)

    def join(
        self,
        queue_name: str,
        interval: int = 100,
        *,
        timeout: int | None = None,
    ):
        """Block until ``queue_name`` has no queued or in-flight tasks left.

        Polls every ``interval`` ms; raises QueueJoinTimeout after ``timeout``
        ms (no timeout when None). Mainly useful in tests.
        """
        deadline = timeout and time.monotonic() + timeout / 1000
        while True:
            if deadline and time.monotonic() >= deadline:
                raise QueueJoinTimeout(queue_name)
            # Fixed: return once NO pending work remains; the original returned
            # while tasks were still queued/consumed and timed out on an empty
            # queue.
            if not self.query_set.filter(
                queue_name=queue_name,
                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
            ).exists():
                return
            time.sleep(interval / 1000)
class _PostgresConsumer(Consumer):
    """Consumer pulling tasks for a single queue.

    Wakeups arrive as postgres NOTIFY events on the queue's enqueue channel;
    a message is claimed by taking a postgres advisory lock so several
    workers can safely share one queue. A randomized table poll catches
    notifications missed while disconnected, and an occasional garbage
    collection purges old finished rows.
    """

    def __init__(
        self,
        *args,
        broker: PostgresBroker,
        db_alias: str,
        queue_name: str,
        prefetch: int,
        timeout: int,
        **kwargs,
    ):
        self.logger = get_logger().bind()
        self.notifies: list[Notify] = []
        self.broker = broker
        self.db_alias = db_alias
        self.queue_name = queue_name
        # dramatiq passes milliseconds; psycopg's notifies() expects seconds.
        self.timeout = timeout // 1000
        # Messages whose advisory locks await release (filled by ack/nack/requeue).
        self.unlock_queue = Queue()
        # message_ids currently being processed by this worker.
        self.in_processing = set()
        self.prefetch = prefetch
        self.misses = 0
        self._listen_connection: DatabaseWrapper | None = None
        self.postgres_channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)

    @property
    def connection(self) -> DatabaseWrapper:
        return connections[self.db_alias]

    @property
    def query_set(self) -> QuerySet:
        return self.broker.query_set

    @property
    def listen_connection(self) -> DatabaseWrapper:
        """Dedicated LISTENing connection, (re)created lazily when missing or lost."""
        if self._listen_connection is not None and self._listen_connection.connection is not None:
            return self._listen_connection
        self._listen_connection = connections[self.db_alias]
        # Required for notifications
        # See https://www.psycopg.org/psycopg3/docs/advanced/async.html#asynchronous-notifications
        # Should be set to True by Django by default
        self._listen_connection.set_autocommit(True)
        with self._listen_connection.cursor() as cursor:
            cursor.execute(sql.SQL("LISTEN {}").format(sql.Identifier(self.postgres_channel)))
        return self._listen_connection

    @raise_connection_error
    def ack(self, message: Message):
        """Mark ``message`` done and schedule its advisory lock for release."""
        self.unlock_queue.put_nowait(message)
        self.query_set.filter(
            message_id=message.message_id,
            queue_name=message.queue_name,
            state=TaskState.CONSUMED,
        ).update(
            state=TaskState.DONE,
            message=message.encode(),
        )
        self.in_processing.remove(message.message_id)

    @raise_connection_error
    def nack(self, message: Message):
        """Mark ``message`` rejected and schedule its advisory lock for release."""
        self.unlock_queue.put_nowait(message)
        self.query_set.filter(
            message_id=message.message_id,
            queue_name=message.queue_name,
        ).exclude(
            state=TaskState.REJECTED,
        ).update(
            state=TaskState.REJECTED,
            message=message.encode(),
        )
        self.in_processing.remove(message.message_id)

    @raise_connection_error
    def requeue(self, messages: Iterable[Message]):
        """Return ``messages`` to the queued state (e.g. on worker shutdown)."""
        # Fixed: materialize once — the iterable is traversed three times
        # below, which would silently drop messages for a generator input.
        messages = list(messages)
        for message in messages:
            self.unlock_queue.put_nowait(message)
        self.query_set.filter(
            message_id__in=[message.message_id for message in messages],
        ).update(
            state=TaskState.QUEUED,
        )
        for message in messages:
            self.in_processing.remove(message.message_id)
        self._purge_locks()

    def _fetch_pending_notifies(self) -> list[Notify]:
        """Build synthetic notifies for every pending task of this queue found in the table."""
        self.logger.debug(f"Polling for lost messages in {self.queue_name}")
        notifies = (
            self.query_set.filter(
                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
                queue_name=self.queue_name,
            )
            .exclude(
                message_id__in=self.in_processing,
            )
            .values_list("message_id", flat=True)
        )
        channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
        # pid=0 marks these as synthesized rather than real postgres notifications.
        return [Notify(pid=0, channel=channel, payload=item) for item in notifies]

    def _poll_for_notify(self):
        """Wait up to ``self.timeout`` seconds for NOTIFY events and buffer them."""
        with self.listen_connection.cursor() as cursor:
            self.logger.debug(f"timeout is {self.timeout}")
            notifies = list(cursor.connection.notifies(timeout=self.timeout))
            self.logger.debug(
                f"Received {len(notifies)} postgres notifies on channel {self.postgres_channel}"
            )
            self.notifies += notifies

    def _get_message_lock_id(self, message: Message) -> int:
        """Advisory-lock id for ``message``, derived from the queue's lock channel name."""
        return _cast_lock_id(
            f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
        )

    def _consume_one(self, message: Message) -> bool:
        """Try to claim ``message``; True when this worker won the advisory lock."""
        if message.message_id in self.in_processing:
            self.logger.debug(f"Message {message.message_id} already consumed by self")
            return False

        result = (
            self.query_set.filter(
                message_id=message.message_id,
                state__in=(TaskState.QUEUED, TaskState.CONSUMED),
            )
            # pg_try_advisory_lock keeps other workers from claiming the row.
            .extra(
                where=["pg_try_advisory_lock(%s)"],
                params=[self._get_message_lock_id(message)],
            )
            .update(
                state=TaskState.CONSUMED,
                mtime=timezone.now(),
            )
        )
        # Exactly one row updated means we now own the message.
        return result == 1

    @raise_connection_error
    def __next__(self):
        # This method is called every second
        # If we don't have a connection yet, fetch missed notifications from the table directly
        if self._listen_connection is None:
            # We might miss a notification between the initial query and the first time we wait for
            # notifications, it doesn't matter because we re-fetch for missed messages later on.
            self.notifies = self._fetch_pending_notifies()
            self.logger.debug(
                f"Found {len(self.notifies)} pending messages in queue {self.queue_name}"
            )

        processing = len(self.in_processing)
        if processing >= self.prefetch:
            # Wait and don't consume the message, other worker will be faster
            self.misses, backoff_ms = compute_backoff(self.misses, max_backoff=1000)
            self.logger.debug(
                f"Too many messages in processing: {processing}. Sleeping {backoff_ms} ms"
            )
            time.sleep(backoff_ms / 1000)
            return None

        if not self.notifies:
            self._poll_for_notify()

        if not self.notifies and not randint(0, 300):  # nosec
            # If there aren't any more notifies, randomly poll for missed/crashed messages.
            # Since this method is called every second, this condition limits polling to
            # on average one SELECT every five minutes of inactivity.
            self.notifies[:] = self._fetch_pending_notifies()

        # If we have some notifies, loop to find one to do
        while self.notifies:
            notify = self.notifies.pop(0)
            task = self.query_set.get(message_id=notify.payload)
            message = Message.decode(task.message)
            if self._consume_one(message):
                self.in_processing.add(message.message_id)
                return MessageProxy(message)
            else:
                self.logger.debug(f"Message {message.message_id} already consumed. Skipping.")

        # No message to process
        self._purge_locks()
        self._auto_purge()
        return None

    def _purge_locks(self):
        """Release every advisory lock queued for unlocking."""
        while True:
            try:
                message = self.unlock_queue.get(block=False)
            except Empty:
                return
            self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
            with self.connection.cursor() as cursor:
                cursor.execute(
                    "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
                )
            self.unlock_queue.task_done()

    def _auto_purge(self):
        # Automatically purge messages on average every 100k iteration.
        # Dramatiq defaults to 1s, so this means one purge every 28 hours.
        if randint(0, 100_000):  # nosec
            return
        self.logger.debug("Running garbage collector")
        # Fixed: QuerySet.delete() returns (total, per-model dict); the
        # original logged the whole tuple as the count.
        count, _ = self.query_set.filter(
            state__in=(TaskState.DONE, TaskState.REJECTED),
            mtime__lte=timezone.now() - timezone.timedelta(days=30),
            result_expiry__lte=timezone.now(),
        ).delete()
        self.logger.info(f"Purged {count} messages in all queues")

View File

@ -16,6 +16,7 @@ class Conf:
middlewares = conf.get(
"middlewares",
(
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
("dramatiq.middleware.time_limit.TimeLimit", {}),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
@ -24,3 +25,9 @@ class Conf:
("dramatiq.middleware.retries.Retries", {}),
),
)
channel_prefix = conf.get("channel_prefix", "dramatiq.tasks")
task_class = conf.get("task_class", None)
test = conf.get("test", False)

View File

@ -0,0 +1,24 @@
from django.db import (
close_old_connections,
connections,
)
from dramatiq.middleware.middleware import Middleware
from django_dramatiq_postgres.conf import Conf
class DbConnectionMiddleware(Middleware):
    """Dramatiq middleware managing Django database connections for workers."""

    def _close_old_connections(self, *args, **kwargs):
        # In test mode the connections are left alone (Conf.test).
        if not Conf.test:
            close_old_connections()

    # Recycle stale/expired connections around each processed message.
    before_process_message = _close_old_connections
    after_process_message = _close_old_connections

    def _close_connections(self, *args, **kwargs):
        connections.close_all()

    # Drop every open connection when consumer/worker threads shut down.
    before_consumer_thread_shutdown = _close_connections
    before_worker_thread_shutdown = _close_connections
    before_worker_shutdown = _close_connections

View File

@ -0,0 +1,67 @@
from enum import StrEnum, auto
from uuid import uuid4
import pgtrigger
from django.db import models
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.conf import Conf
# Prefix for all postgres NOTIFY channel names (also used for advisory-lock
# name derivation); configurable via Conf.channel_prefix.
CHANNEL_PREFIX = f"{Conf.channel_prefix}.tasks"
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
class TaskState(models.TextChoices):
    """Task system-state. Reported by the task runners"""

    # Waiting in the table for a consumer to claim it.
    QUEUED = "queued"
    # Claimed by a worker and currently being processed.
    CONSUMED = "consumed"
    # Negatively acknowledged (nacked) by a worker.
    REJECTED = "rejected"
    # Acknowledged; processing finished.
    DONE = "done"
class Task(models.Model):
    """Abstract dramatiq task/message persisted in postgres.

    Concrete subclasses are selected via ``Conf.task_class``. A deferred
    AFTER INSERT/UPDATE trigger NOTIFYs the queue's enqueue channel whenever
    a row enters the ``queued`` state, waking up listening consumers.
    """

    # Doubles as the dramatiq message id and the row's primary key.
    message_id = models.UUIDField(primary_key=True, default=uuid4)
    queue_name = models.TextField(default="default", help_text=_("Queue name"))
    actor_name = models.TextField(help_text=_("Dramatiq actor name"))
    # Encoded dramatiq message body.
    message = models.BinaryField(null=True, help_text=_("Message body"))
    state = models.CharField(
        default=TaskState.QUEUED,
        choices=TaskState.choices,
        help_text=_("Task status"),
    )
    mtime = models.DateTimeField(default=timezone.now, help_text=_("Task last modified time"))
    result = models.BinaryField(null=True, help_text=_("Task result"))
    result_expiry = models.DateTimeField(null=True, help_text=_("Result expiry time"))

    class Meta:
        abstract = True
        verbose_name = _("Task")
        verbose_name_plural = _("Tasks")
        # Supports the frequent state+mtime filters used by the broker.
        indexes = (models.Index(fields=("state", "mtime")),)
        triggers = (
            pgtrigger.Trigger(
                name="notify_enqueueing",
                operation=pgtrigger.Insert | pgtrigger.Update,
                when=pgtrigger.After,
                condition=pgtrigger.Q(new__state=TaskState.QUEUED),
                timing=pgtrigger.Deferred,
                func=f"""
                    PERFORM pg_notify(
                        '{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
                        NEW.message_id::text
                    );
                    RETURN NEW;
                """,  # noqa: E501
            ),
        )

    def __str__(self):
        return str(self.message_id)

View File

@ -1,7 +1,7 @@
[project]
name = "django-dramatiq-postgres"
version = "0.1.0"
description = "Django and DramatiQ integration with postgres-specific features"
description = "Django and Dramatiq integration with postgres-specific features"
requires-python = ">=3.9,<3.14"
readme = "README.md"
license = "MIT"