move middlewares to package

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-19 14:57:17 +02:00
parent 8980282a02
commit 5a5176e21f
20 changed files with 141 additions and 116 deletions

View File

@ -23,6 +23,7 @@ from django.utils.module_loading import import_string
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.logging import get_logger
from dramatiq.message import Message
from dramatiq.middleware import (
Middleware,
@ -30,12 +31,11 @@ from dramatiq.middleware import (
from pglock.core import _cast_lock_id
from psycopg import Notify, sql
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
LOGGER = get_logger()
logger = get_logger(__name__)
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
@ -62,7 +62,7 @@ class PostgresBroker(Broker):
**kwargs,
):
super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger().bind()
self.logger = get_logger(__name__, type(self))
self.queues = set()
@ -131,7 +131,7 @@ class PostgresBroker(Broker):
reraise=True,
wait=tenacity.wait_random_exponential(multiplier=1, max=30),
stop=tenacity.stop_after_attempt(10),
before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True),
before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
)
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
canonical_queue_name = message.queue_name
@ -148,20 +148,26 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
message.options["model_defaults"] = self.model_defaults(message)
self.emit_before("enqueue", message, delay)
query = {
"message_id": message.message_id,
}
defaults = self.model_defaults(message)
defaults = message.options["model_defaults"]
del message.options["model_defaults"]
create_defaults = {
**query,
**defaults,
}
self.query_set.update_or_create(
**query,
defaults=defaults,
create_defaults=create_defaults,
)
self.emit_after("enqueue", message, delay)
return message
@ -209,7 +215,7 @@ class _PostgresConsumer(Consumer):
timeout: int,
**kwargs,
):
self.logger = get_logger().bind()
self.logger = get_logger(__name__, type(self))
self.notifies: list[Notify] = []
self.broker = broker

View File

@ -1,10 +1,18 @@
import contextvars
from typing import Any
from django.db import (
close_old_connections,
connections,
)
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.logging import get_logger
from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware
from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.models import TaskBase
class DbConnectionMiddleware(Middleware):
@ -22,3 +30,49 @@ class DbConnectionMiddleware(Middleware):
before_consumer_thread_shutdown = _close_connections
before_worker_thread_shutdown = _close_connections
before_worker_shutdown = _close_connections
class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor):
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
class CurrentTask(Middleware):
def __init__(self):
self.logger = get_logger(__name__, type(self))
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
_TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
"_TASKS",
default=None,
)
@classmethod
def get_task(cls) -> TaskBase:
task = cls._TASKS.get()
if not task:
raise RuntimeError("CurrentTask.get_task() can only be called in a running task")
return task[-1]
def before_process_message(self, broker: Broker, message: Message):
tasks = self._TASKS.get()
if tasks is None:
tasks = []
tasks.append(message.options["task"])
self._TASKS.set(tasks)
def after_process_message(
self,
broker: Broker,
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
):
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
self.logger.warning("Task was None, not saving. This should not happen.")
return
else:
tasks[-1].save()
self._TASKS.set(tasks[:-1])