move middlewares to package
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
@ -23,6 +23,7 @@ from django.utils.module_loading import import_string
|
||||
from dramatiq.broker import Broker, Consumer, MessageProxy
|
||||
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
||||
from dramatiq.logging import get_logger
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware import (
|
||||
Middleware,
|
||||
@ -30,12 +31,11 @@ from dramatiq.middleware import (
|
||||
from pglock.core import _cast_lock_id
|
||||
from psycopg import Notify, sql
|
||||
from psycopg.errors import AdminShutdown
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from django_dramatiq_postgres.conf import Conf
|
||||
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
|
||||
|
||||
LOGGER = get_logger()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||
@ -62,7 +62,7 @@ class PostgresBroker(Broker):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, middleware=[], **kwargs)
|
||||
self.logger = get_logger().bind()
|
||||
self.logger = get_logger(__name__, type(self))
|
||||
|
||||
self.queues = set()
|
||||
|
||||
@ -131,7 +131,7 @@ class PostgresBroker(Broker):
|
||||
reraise=True,
|
||||
wait=tenacity.wait_random_exponential(multiplier=1, max=30),
|
||||
stop=tenacity.stop_after_attempt(10),
|
||||
before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True),
|
||||
before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
|
||||
)
|
||||
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
|
||||
canonical_queue_name = message.queue_name
|
||||
@ -148,20 +148,26 @@ class PostgresBroker(Broker):
|
||||
|
||||
self.declare_queue(canonical_queue_name)
|
||||
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
||||
|
||||
message.options["model_defaults"] = self.model_defaults(message)
|
||||
self.emit_before("enqueue", message, delay)
|
||||
|
||||
query = {
|
||||
"message_id": message.message_id,
|
||||
}
|
||||
defaults = self.model_defaults(message)
|
||||
defaults = message.options["model_defaults"]
|
||||
del message.options["model_defaults"]
|
||||
create_defaults = {
|
||||
**query,
|
||||
**defaults,
|
||||
}
|
||||
|
||||
self.query_set.update_or_create(
|
||||
**query,
|
||||
defaults=defaults,
|
||||
create_defaults=create_defaults,
|
||||
)
|
||||
|
||||
self.emit_after("enqueue", message, delay)
|
||||
return message
|
||||
|
||||
@ -209,7 +215,7 @@ class _PostgresConsumer(Consumer):
|
||||
timeout: int,
|
||||
**kwargs,
|
||||
):
|
||||
self.logger = get_logger().bind()
|
||||
self.logger = get_logger(__name__, type(self))
|
||||
|
||||
self.notifies: list[Notify] = []
|
||||
self.broker = broker
|
||||
|
||||
@ -1,10 +1,18 @@
|
||||
import contextvars
|
||||
from typing import Any
|
||||
|
||||
from django.db import (
|
||||
close_old_connections,
|
||||
connections,
|
||||
)
|
||||
from dramatiq.actor import Actor
|
||||
from dramatiq.broker import Broker
|
||||
from dramatiq.logging import get_logger
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware.middleware import Middleware
|
||||
|
||||
from django_dramatiq_postgres.conf import Conf
|
||||
from django_dramatiq_postgres.models import TaskBase
|
||||
|
||||
|
||||
class DbConnectionMiddleware(Middleware):
|
||||
@ -22,3 +30,49 @@ class DbConnectionMiddleware(Middleware):
|
||||
before_consumer_thread_shutdown = _close_connections
|
||||
before_worker_thread_shutdown = _close_connections
|
||||
before_worker_shutdown = _close_connections
|
||||
|
||||
|
||||
class FullyQualifiedActorName(Middleware):
|
||||
def before_declare_actor(self, broker: Broker, actor: Actor):
|
||||
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
|
||||
|
||||
|
||||
class CurrentTask(Middleware):
|
||||
def __init__(self):
|
||||
self.logger = get_logger(__name__, type(self))
|
||||
|
||||
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
|
||||
_TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
|
||||
"_TASKS",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_task(cls) -> TaskBase:
|
||||
task = cls._TASKS.get()
|
||||
if not task:
|
||||
raise RuntimeError("CurrentTask.get_task() can only be called in a running task")
|
||||
return task[-1]
|
||||
|
||||
def before_process_message(self, broker: Broker, message: Message):
|
||||
tasks = self._TASKS.get()
|
||||
if tasks is None:
|
||||
tasks = []
|
||||
tasks.append(message.options["task"])
|
||||
self._TASKS.set(tasks)
|
||||
|
||||
def after_process_message(
|
||||
self,
|
||||
broker: Broker,
|
||||
message: Message,
|
||||
*,
|
||||
result: Any | None = None,
|
||||
exception: Exception | None = None,
|
||||
):
|
||||
tasks: list[TaskBase] | None = self._TASKS.get()
|
||||
if tasks is None or len(tasks) == 0:
|
||||
self.logger.warning("Task was None, not saving. This should not happen.")
|
||||
return
|
||||
else:
|
||||
tasks[-1].save()
|
||||
self._TASKS.set(tasks[:-1])
|
||||
|
||||
Reference in New Issue
Block a user