move all broker stuff to package, schedule is still todo
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
@ -7,6 +7,7 @@ from random import randint
|
||||
from typing import Any
|
||||
|
||||
import tenacity
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
DatabaseError,
|
||||
@ -24,14 +25,7 @@ from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware import (
|
||||
AgeLimit,
|
||||
Callbacks,
|
||||
Middleware,
|
||||
Pipelines,
|
||||
Prometheus,
|
||||
Retries,
|
||||
ShutdownNotifications,
|
||||
TimeLimit,
|
||||
)
|
||||
from pglock.core import _cast_lock_id
|
||||
from psycopg import Notify, sql
|
||||
@ -39,7 +33,6 @@ from psycopg.errors import AdminShutdown
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from django_dramatiq_postgres.conf import Conf
|
||||
from django_dramatiq_postgres.middleware import DbConnectionMiddleware
|
||||
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -75,20 +68,10 @@ class PostgresBroker(Broker):
|
||||
|
||||
self.db_alias = db_alias
|
||||
self.middleware = []
|
||||
if middleware is None:
|
||||
for m in (
|
||||
DbConnectionMiddleware,
|
||||
Prometheus,
|
||||
AgeLimit,
|
||||
TimeLimit,
|
||||
ShutdownNotifications,
|
||||
Callbacks,
|
||||
Pipelines,
|
||||
Retries,
|
||||
):
|
||||
self.add_middleware(m())
|
||||
for m in middleware or []:
|
||||
self.add_middleware(m)
|
||||
if middleware:
|
||||
raise ImproperlyConfigured(
|
||||
"Middlewares should be set in django settings, not passed directly to the broker."
|
||||
)
|
||||
|
||||
@property
|
||||
def connection(self) -> DatabaseWrapper:
|
||||
@ -100,7 +83,7 @@ class PostgresBroker(Broker):
|
||||
|
||||
@cached_property
|
||||
def model(self) -> type[TaskBase]:
|
||||
return import_string(Conf.task_class)
|
||||
return import_string(Conf().task_class)
|
||||
|
||||
@property
|
||||
def query_set(self) -> QuerySet:
|
||||
@ -267,7 +250,6 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
@raise_connection_error
|
||||
def ack(self, message: Message):
|
||||
self.unlock_queue.put_nowait(message)
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@ -276,11 +258,11 @@ class _PostgresConsumer(Consumer):
|
||||
state=TaskState.DONE,
|
||||
message=message.encode(),
|
||||
)
|
||||
self.unlock_queue.put_nowait(message.message_id)
|
||||
self.in_processing.remove(message.message_id)
|
||||
|
||||
@raise_connection_error
|
||||
def nack(self, message: Message):
|
||||
self.unlock_queue.put_nowait(message)
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@ -290,18 +272,18 @@ class _PostgresConsumer(Consumer):
|
||||
state=TaskState.REJECTED,
|
||||
message=message.encode(),
|
||||
)
|
||||
self.unlock_queue.put_nowait(message.message_id)
|
||||
self.in_processing.remove(message.message_id)
|
||||
|
||||
@raise_connection_error
|
||||
def requeue(self, messages: Iterable[Message]):
|
||||
for message in messages:
|
||||
self.unlock_queue.put_nowait(message)
|
||||
self.query_set.filter(
|
||||
message_id__in=[message.message_id for message in messages],
|
||||
).update(
|
||||
state=TaskState.QUEUED,
|
||||
)
|
||||
for message in messages:
|
||||
self.unlock_queue.put_nowait(message.message_id)
|
||||
self.in_processing.remove(message.message_id)
|
||||
self._purge_locks()
|
||||
|
||||
@ -328,9 +310,9 @@ class _PostgresConsumer(Consumer):
|
||||
)
|
||||
self.notifies += notifies
|
||||
|
||||
def _get_message_lock_id(self, message: Message) -> int:
|
||||
def _get_message_lock_id(self, message_id: str) -> int:
|
||||
return _cast_lock_id(
|
||||
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
|
||||
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message_id}"
|
||||
)
|
||||
|
||||
def _consume_one(self, message: Message) -> bool:
|
||||
@ -345,7 +327,7 @@ class _PostgresConsumer(Consumer):
|
||||
)
|
||||
.extra(
|
||||
where=["pg_try_advisory_lock(%s)"],
|
||||
params=[self._get_message_lock_id(message)],
|
||||
params=[self._get_message_lock_id(message.message_id)],
|
||||
)
|
||||
.update(
|
||||
state=TaskState.CONSUMED,
|
||||
@ -391,6 +373,7 @@ class _PostgresConsumer(Consumer):
|
||||
notify = self.notifies.pop(0)
|
||||
task = self.query_set.get(message_id=notify.payload)
|
||||
message = Message.decode(task.message)
|
||||
message.task = task
|
||||
if self._consume_one(message):
|
||||
self.in_processing.add(message.message_id)
|
||||
return MessageProxy(message)
|
||||
@ -404,17 +387,18 @@ class _PostgresConsumer(Consumer):
|
||||
def _purge_locks(self):
|
||||
while True:
|
||||
try:
|
||||
message = self.unlock_queue.get(block=False)
|
||||
message_id = self.unlock_queue.get(block=False)
|
||||
except Empty:
|
||||
return
|
||||
self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
|
||||
self.logger.debug(f"Unlocking message {message_id}")
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
|
||||
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),)
|
||||
)
|
||||
self.unlock_queue.task_done()
|
||||
|
||||
def _auto_purge(self):
|
||||
# TODO: allow configuring this
|
||||
# Automatically purge messages on average every 100k iteration.
|
||||
# Dramatiq defaults to 1s, so this means one purge every 28 hours.
|
||||
if randint(0, 100_000): # nosec
|
||||
@ -422,6 +406,7 @@ class _PostgresConsumer(Consumer):
|
||||
self.logger.debug("Running garbage collector")
|
||||
count = self.query_set.filter(
|
||||
state__in=(TaskState.DONE, TaskState.REJECTED),
|
||||
# TODO: allow configuring this
|
||||
mtime__lte=timezone.now() - timezone.timedelta(days=30),
|
||||
result_expiry__lte=timezone.now(),
|
||||
).delete()
|
||||
|
||||
Reference in New Issue
Block a user