diff --git a/authentik/admin/tasks.py b/authentik/admin/tasks.py index 140c6e75d8..5834b9ff2e 100644 --- a/authentik/admin/tasks.py +++ b/authentik/admin/tasks.py @@ -2,6 +2,7 @@ from django.core.cache import cache from django.utils.translation import gettext_lazy as _ +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq import actor from packaging.version import parse from requests import RequestException @@ -12,7 +13,7 @@ from authentik.admin.apps import PROM_INFO from authentik.events.models import Event, EventAction from authentik.lib.config import CONFIG from authentik.lib.utils.http import get_http_session -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() VERSION_NULL = "0.0.0" @@ -35,7 +36,7 @@ def _set_prom_info(): @actor def update_latest_version(): """Update latest version info""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() if CONFIG.get_bool("disable_update_check"): cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) self.info("Version check disabled.") diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index da2df8cd65..cfa21e5017 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -10,6 +10,7 @@ from dacite.core import from_dict from django.db import DatabaseError, InternalError, ProgrammingError from django.utils.text import slugify from django.utils.timezone import now +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from dramatiq.middleware import Middleware from structlog.stdlib import get_logger @@ -35,7 +36,7 @@ from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.events.logs import capture_logs from authentik.events.utils import sanitize_dict from authentik.lib.config import CONFIG -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task from authentik.tasks.schedules.models import Schedule from authentik.tenants.models import Tenant @@ -147,7 +148,7 @@ def blueprints_find() -> list[BlueprintFile]: @actor(throws=(DatabaseError, ProgrammingError, InternalError)) def blueprints_discovery(path: str | None = None): """Find blueprints and check if they need to be created in the database""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() count = 0 for blueprint in blueprints_find(): if path and blueprint.path != path: @@ -187,7 +188,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): @actor def apply_blueprint(instance_pk: UUID): """Apply single blueprint""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() self.set_uid(str(instance_pk)) instance: BlueprintInstance | None = None try: diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index d5764bbb13..ed43b8dbd6 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from django.utils.timezone import now +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from structlog.stdlib import get_logger @@ -12,7 +13,7 @@ from authentik.core.models import ( ExpiringModel, User, ) -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -20,7 +21,7 @@ LOGGER = get_logger() @actor def clean_expired_models(): """Remove expired objects""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() for cls in ExpiringModel.__subclasses__(): cls: ExpiringModel objects = ( @@ -36,7 +37,7 @@ def clean_expired_models(): @actor def clean_temporary_users(): """Remove temporary users created by SAML Sources""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() _now = datetime.now() deleted_users = 0 for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): diff --git a/authentik/crypto/tasks.py b/authentik/crypto/tasks.py index af3f935f6b..84be597b22 100644 --- a/authentik/crypto/tasks.py +++ b/authentik/crypto/tasks.py @@ -6,12 +6,13 @@ from pathlib import Path from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.x509.base import load_pem_x509_certificate +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from structlog.stdlib import get_logger from authentik.crypto.models import CertificateKeyPair from authentik.lib.config import CONFIG -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -37,7 +38,7 @@ def ensure_certificate_valid(body: str): @actor def certificate_discovery(): """Discover, import and update certificates from the filesystem""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() certs = {} private_keys = {} discovered = 0 diff --git a/authentik/enterprise/policies/unique_password/tasks.py b/authentik/enterprise/policies/unique_password/tasks.py index 091119be2f..e0ac67cd6f 100644 --- a/authentik/enterprise/policies/unique_password/tasks.py +++ b/authentik/enterprise/policies/unique_password/tasks.py @@ -1,4 +1,5 @@ from django.db.models.aggregates import Count +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from structlog import get_logger @@ -6,7 +7,7 @@ from authentik.enterprise.policies.unique_password.models import ( UniquePasswordPolicy, UserPasswordHistory, ) -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -16,7 +17,7 @@ def check_and_purge_password_history(): """Check if any UniquePasswordPolicy exists, and if not, purge the password history table. This is run on a schedule instead of being triggered by policy binding deletion. """ - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() if not UniquePasswordPolicy.objects.exists(): UserPasswordHistory.objects.all().delete() @@ -29,7 +30,7 @@ def check_and_purge_password_history(): @actor def trim_password_histories(): - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() """Removes rows from UserPasswordHistory older than the `n` most recent entries. diff --git a/authentik/enterprise/providers/ssf/tasks.py b/authentik/enterprise/providers/ssf/tasks.py index 64030303a7..d12a8ba36b 100644 --- a/authentik/enterprise/providers/ssf/tasks.py +++ b/authentik/enterprise/providers/ssf/tasks.py @@ -3,6 +3,7 @@ from uuid import UUID from django.http import HttpRequest from django.utils.timezone import now +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from requests.exceptions import RequestException from structlog.stdlib import get_logger @@ -19,7 +20,7 @@ from authentik.events.logs import LogEvent from authentik.lib.utils.http import get_http_session from authentik.lib.utils.time import timedelta_from_string from authentik.policies.engine import PolicyEngine -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task session = get_http_session() LOGGER = get_logger() @@ -62,7 +63,7 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool: @actor def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() stream = Stream.objects.filter(pk=stream_uuid).first() if not stream: diff --git a/authentik/events/tasks.py b/authentik/events/tasks.py index 15e8a933a5..b80d99dbae 100644 --- a/authentik/events/tasks.py +++ b/authentik/events/tasks.py @@ -3,6 +3,7 @@ from uuid import UUID from django.db.models.query_utils import Q +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from guardian.shortcuts import get_anonymous_user from structlog.stdlib import get_logger @@ -16,7 +17,7 @@ from authentik.events.models import ( ) from authentik.policies.engine import PolicyEngine from authentik.policies.models import PolicyBinding, PolicyEngineMode -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -110,7 +111,7 @@ def gdpr_cleanup(user_pk: int): @actor def notification_cleanup(): """Cleanup seen notifications and notifications whose event expired.""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) amount = notifications.count() notifications.delete() diff --git a/authentik/lib/sync/outgoing/tasks.py b/authentik/lib/sync/outgoing/tasks.py index 938b12de60..1ca17863af 100644 --- a/authentik/lib/sync/outgoing/tasks.py +++ b/authentik/lib/sync/outgoing/tasks.py @@ -2,6 +2,7 @@ from django.core.paginator import Paginator from django.db.models import Model, QuerySet from django.db.models.query import Q from django.utils.text import slugify +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import Actor from dramatiq.composition import group from dramatiq.errors import Retry @@ -20,7 +21,6 @@ from authentik.lib.sync.outgoing.exceptions import ( ) from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.utils.reflection import class_to_path, path_to_class -from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task @@ -60,7 +60,7 @@ class SyncTasks: provider_pk: int, sync_objects: Actor, ): - task = CurrentTask.get_task() + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), provider_pk=provider_pk, @@ -114,7 +114,7 @@ class SyncTasks: override_dry_run=False, **filter, ): - task = CurrentTask.get_task() + task: Task = CurrentTask.get_task() _object_type: type[Model] = path_to_class(object_type) self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), @@ -186,7 +186,7 @@ class SyncTasks: provider_pk: int, raw_op: str, ): - task = CurrentTask.get_task() + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) @@ -234,7 +234,7 @@ class SyncTasks: action: str, pk_set: list[int], ): - task = CurrentTask.get_task() + task: Task = CurrentTask.get_task() self.logger = get_logger().bind( provider_type=class_to_path(self._provider_model), ) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 70c6ca0a30..17ef324edf 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -12,6 +12,7 @@ from channels.layers import get_channel_layer from django.core.cache import cache from django.db.models.base import Model from django.utils.text import slugify +from django_dramatiq_postgres.middleware import CurrentTask from docker.constants import DEFAULT_UNIX_SOCKET from dramatiq.actor import actor from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME @@ -42,7 +43,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController from authentik.providers.rac.controllers.kubernetes import RACKubernetesController from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" @@ -109,7 +110,7 @@ def outpost_service_connection_monitor(connection_pk: Any): @actor def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): """Create/update/monitor/delete the deployment of an Outpost""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() self.set_uid(outpost_pk) logs = [] if from_cache: @@ -144,7 +145,7 @@ def outpost_token_ensurer(): """ Periodically ensure that all Outposts have valid Service Accounts and Tokens """ - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() all_outposts = Outpost.objects.all() for outpost in all_outposts: _ = outpost.token @@ -227,7 +228,7 @@ def _outpost_single_update(outpost: Outpost, layer=None): @actor def outpost_connection_discovery(): """Checks the local environment and create Service connections.""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() if not CONFIG.get_bool("outposts.discover"): self.info("Outpost integration discovery is disabled") return diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 2accd91ebc..8cc93187f1 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -367,6 +367,7 @@ DRAMATIQ = { "threads": CONFIG.get_int("worker.threads", 1), }, "middlewares": ( + ("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}), # TODO: fixme # ("dramatiq.middleware.prometheus.Prometheus", {}), ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), @@ -384,10 +385,9 @@ DRAMATIQ = { {"max_retries": 20 if not TEST else 0}, ), # TODO: results - ("authentik.tasks.middleware.FullyQualifiedActorName", {}), - ("authentik.tasks.middleware.RelObjMiddleware", {}), ("authentik.tasks.middleware.TenantMiddleware", {}), - ("authentik.tasks.middleware.CurrentTask", {}), + ("authentik.tasks.middleware.RelObjMiddleware", {}), + ("django_dramatiq_postgres.middleware.CurrentTask", {}), ), "test": TEST, } diff --git a/authentik/sources/kerberos/tasks.py b/authentik/sources/kerberos/tasks.py index 35ee7bf754..698bb05ac3 100644 --- a/authentik/sources/kerberos/tasks.py +++ b/authentik/sources/kerberos/tasks.py @@ -1,6 +1,7 @@ """Kerberos Sync tasks""" from django.core.cache import cache +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from structlog.stdlib import get_logger @@ -9,7 +10,7 @@ from authentik.lib.sync.outgoing.exceptions import StopSync from authentik.lib.utils.errors import exception_to_string from authentik.sources.kerberos.models import KerberosSource from authentik.sources.kerberos.sync import KerberosSync -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/" @@ -30,7 +31,7 @@ def kerberos_connectivity_check(pk: str): @actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000) def kerberos_sync(pk: str): """Sync a single source""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first() if not source: return diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 9fa36d0c28..627cb8796e 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -3,6 +3,7 @@ from uuid import uuid4 from django.core.cache import cache +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from dramatiq.composition import group from dramatiq.message import Message @@ -20,7 +21,7 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() SYNC_CLASSES = [ @@ -118,7 +119,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> @actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000) def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() if not source: # Because the source couldn't be found, we don't have a UID diff --git a/authentik/sources/oauth/tasks.py b/authentik/sources/oauth/tasks.py index f284d62a21..4c233470c2 100644 --- a/authentik/sources/oauth/tasks.py +++ b/authentik/sources/oauth/tasks.py @@ -2,13 +2,14 @@ from json import dumps +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from requests import RequestException from structlog.stdlib import get_logger from authentik.lib.utils.http import get_http_session from authentik.sources.oauth.models import OAuthSource -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -16,7 +17,7 @@ LOGGER = get_logger() @actor def update_well_known_jwks(): """Update OAuth sources' config from well_known, and JWKS info from the configured URL""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() session = get_http_session() for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""): try: diff --git a/authentik/sources/plex/tasks.py b/authentik/sources/plex/tasks.py index fdb00796de..11bb34e57e 100644 --- a/authentik/sources/plex/tasks.py +++ b/authentik/sources/plex/tasks.py @@ -1,5 +1,6 @@ """Plex tasks""" +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from requests import RequestException @@ -7,13 +8,13 @@ from authentik.events.models import Event, EventAction from authentik.lib.utils.errors import exception_to_string from authentik.sources.plex.models import PlexSource from authentik.sources.plex.plex import PlexAuth -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task @actor def check_plex_token(source_pk: str): """Check the validity of a Plex source.""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() sources = PlexSource.objects.filter(pk=source_pk) if not sources.exists(): return diff --git a/authentik/stages/authenticator_webauthn/tasks.py b/authentik/stages/authenticator_webauthn/tasks.py index e7b9f9e110..ddc2ec91fd 100644 --- a/authentik/stages/authenticator_webauthn/tasks.py +++ b/authentik/stages/authenticator_webauthn/tasks.py @@ -6,6 +6,7 @@ from pathlib import Path from django.core.cache import cache from django.db.transaction import atomic +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from fido2.mds3 import filter_revoked, parse_blob @@ -13,7 +14,7 @@ from authentik.stages.authenticator_webauthn.models import ( UNKNOWN_DEVICE_TYPE_AAGUID, WebAuthnDeviceType, ) -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no" AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json" @@ -31,7 +32,7 @@ def mds_ca() -> bytes: @actor def webauthn_mds_import(force=False): """Background task to import FIDO Alliance MDS blob and AAGUIDs into database""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() with open(MDS_BLOB_PATH, mode="rb") as _raw_blob: blob = parse_blob(_raw_blob.read(), mds_ca()) to_create_update = [ diff --git a/authentik/stages/email/tasks.py b/authentik/stages/email/tasks.py index 4dc84fd5f5..2bfd831748 100644 --- a/authentik/stages/email/tasks.py +++ b/authentik/stages/email/tasks.py @@ -6,6 +6,7 @@ from typing import Any from django.core.mail import EmailMultiAlternatives from django.core.mail.utils import DNS_NAME from django.utils.text import slugify +from django_dramatiq_postgres.middleware import CurrentTask from dramatiq.actor import actor from dramatiq.composition import group from structlog.stdlib import get_logger @@ -15,7 +16,7 @@ from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.stages.authenticator_email.models import AuthenticatorEmailStage from authentik.stages.email.models import EmailStage from authentik.stages.email.utils import logo_data -from authentik.tasks.middleware import CurrentTask +from authentik.tasks.models import Task LOGGER = get_logger() @@ -54,7 +55,7 @@ def send_mail( email_stage_pk: str | None = None, ): """Send Email for Email Stage. Retries are scheduled automatically.""" - self = CurrentTask.get_task() + self: Task = CurrentTask.get_task() message_id = make_msgid(domain=DNS_NAME) self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_"))) if not stage_class_path or not email_stage_pk: diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 41711de4e7..7bf5b92083 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,21 +1,11 @@ -from typing import Any - +from django.db.models import QuerySet from django_dramatiq_postgres.broker import PostgresBroker -from dramatiq.message import Message from structlog.stdlib import get_logger -from authentik.tenants.utils import get_current_tenant - LOGGER = get_logger() class Broker(PostgresBroker): - def model_defaults(self, message: Message) -> dict[str, Any]: - rel_obj = message.options.get("rel_obj") - if rel_obj: - del message.options["rel_obj"] - return { - "tenant": get_current_tenant(), - "rel_obj": rel_obj, - **super().model_defaults(message), - } + @property + def query_set(self) -> QuerySet: + return self.model.objects.select_related("tenant").using(self.db_alias) diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py index 4fe6f8b080..9e10fd5b9d 100644 --- a/authentik/tasks/middleware.py +++ b/authentik/tasks/middleware.py @@ -1,16 +1,22 @@ -import contextvars -from typing import Any - -from dramatiq.actor import Actor from dramatiq.broker import Broker from dramatiq.message import Message from dramatiq.middleware import Middleware -from structlog.stdlib import get_logger from authentik.tasks.models import Task from authentik.tenants.models import Tenant +from authentik.tenants.utils import get_current_tenant -LOGGER = get_logger() + +class TenantMiddleware(Middleware): + def before_enqueue(self, broker: Broker, message: Message, delay: int): + message.options["model_defaults"]["tenant"] = get_current_tenant() + + def before_process_message(self, broker: Broker, message: Message): + task: Task = message.options["task"] + task.tenant.activate() + + def after_process_message(self, *args, **kwargs): + Tenant.deactivate() class RelObjMiddleware(Middleware): @@ -18,52 +24,7 @@ class RelObjMiddleware(Middleware): def actor_options(self): return {"rel_obj"} - -class FullyQualifiedActorName(Middleware): - def before_declare_actor(self, broker: Broker, actor: Actor): - actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" - - -class CurrentTask(Middleware): - _TASK: contextvars.ContextVar[list[Task] | None] = contextvars.ContextVar( - "_TASK", - default=None, - ) - - @classmethod - def get_task(cls) -> Task: - task = cls._TASK.get() - if not task: - raise RuntimeError("CurrentTask.get_task() should only be called in a running task") - return task[-1] - - def before_process_message(self, broker: Broker, message: Message): - tasks = self._TASK.get() - if tasks is None: - tasks = [] - tasks.append(Task.objects.get(message_id=message.message_id)) - self._TASK.set(tasks) - - def after_process_message( - self, - broker: Broker, - message: Message, - *, - result: Any | None = None, - exception: Exception | None = None, - ): - tasks: list[Task] | None = self._TASK.get() - if tasks is None or len(tasks) == 0: - LOGGER.warn("Task was None, not saving. This should not happen") - return - else: - tasks[-1].save() - self._TASK.set(tasks[:-1]) - - -class TenantMiddleware(Middleware): - def before_process_message(self, broker: Broker, message: Message): - Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate() - - def after_process_message(self, *args, **kwargs): - Tenant.deactivate() + def before_enqueue(self, broker: Broker, message: Message, delay: int): + if rel_obj := message.options.get("rel_obj"): + del message.options["rel_obj"] + message.options["model_defaults"]["rel_obj"] = rel_obj diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py index 68408913b8..de958e83fc 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py @@ -23,6 +23,7 @@ from django.utils.module_loading import import_string from dramatiq.broker import Broker, Consumer, MessageProxy from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name from dramatiq.errors import ConnectionError, QueueJoinTimeout +from dramatiq.logging import get_logger from dramatiq.message import Message from dramatiq.middleware import ( Middleware, @@ -30,12 +31,11 @@ from dramatiq.middleware import ( from pglock.core import _cast_lock_id from psycopg import Notify, sql from psycopg.errors import AdminShutdown -from structlog.stdlib import get_logger from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState -LOGGER = get_logger() +logger = get_logger(__name__) def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: @@ -62,7 +62,7 @@ class PostgresBroker(Broker): **kwargs, ): super().__init__(*args, middleware=[], **kwargs) - self.logger = get_logger().bind() + self.logger = get_logger(__name__, type(self)) self.queues = set() @@ -131,7 +131,7 @@ class PostgresBroker(Broker): reraise=True, wait=tenacity.wait_random_exponential(multiplier=1, max=30), stop=tenacity.stop_after_attempt(10), - before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True), + before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True), ) def enqueue(self, message: Message, *, delay: int | None = None) -> Message: canonical_queue_name = message.queue_name @@ -148,20 +148,26 @@ class PostgresBroker(Broker): self.declare_queue(canonical_queue_name) self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") + + message.options["model_defaults"] = self.model_defaults(message) self.emit_before("enqueue", message, delay) + query = { "message_id": message.message_id, } - defaults = self.model_defaults(message) + defaults = message.options["model_defaults"] + del message.options["model_defaults"] create_defaults = { **query, **defaults, } + self.query_set.update_or_create( **query, defaults=defaults, create_defaults=create_defaults, ) + self.emit_after("enqueue", message, delay) return message @@ -209,7 +215,7 @@ class _PostgresConsumer(Consumer): timeout: int, **kwargs, ): - self.logger = get_logger().bind() + self.logger = get_logger(__name__, type(self)) self.notifies: list[Notify] = [] self.broker = broker diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py index 0f931b7ebd..79175934e0 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -1,10 +1,18 @@ +import contextvars +from typing import Any + from django.db import ( close_old_connections, connections, ) +from dramatiq.actor import Actor +from dramatiq.broker import Broker +from dramatiq.logging import get_logger +from dramatiq.message import Message from dramatiq.middleware.middleware import Middleware from django_dramatiq_postgres.conf import Conf +from django_dramatiq_postgres.models import TaskBase class DbConnectionMiddleware(Middleware): @@ -22,3 +30,49 @@ class DbConnectionMiddleware(Middleware): before_consumer_thread_shutdown = _close_connections before_worker_thread_shutdown = _close_connections before_worker_shutdown = _close_connections + + +class FullyQualifiedActorName(Middleware): + def before_declare_actor(self, broker: Broker, actor: Actor): + actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" + + +class CurrentTask(Middleware): + def __init__(self): + self.logger = get_logger(__name__, type(self)) + + # This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile + _TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar( + "_TASKS", + default=None, + ) + + @classmethod + def get_task(cls) -> TaskBase: + task = cls._TASKS.get() + if not task: + raise RuntimeError("CurrentTask.get_task() can only be called in a running task") + return task[-1] + + def before_process_message(self, broker: Broker, message: Message): + tasks = self._TASKS.get() + if tasks is None: + tasks = [] + tasks.append(message.options["task"]) + self._TASKS.set(tasks) + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + tasks: list[TaskBase] | None = self._TASKS.get() + if tasks is None or len(tasks) == 0: + self.logger.warning("Task was None, not saving. This should not happen.") + return + else: + tasks[-1].save() + self._TASKS.set(tasks[:-1])