diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index cd734c89ab..814295fcb2 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,3 +1,4 @@ +from dramatiq.middleware import Middleware from psycopg import sql import functools import logging @@ -8,7 +9,14 @@ from random import randint import orjson import tenacity -from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections +from django.db import ( + DEFAULT_DB_ALIAS, + DatabaseError, + InterfaceError, + OperationalError, + close_old_connections, + connections, +) from django.db.backends.postgresql.base import DatabaseWrapper from django.db.models import QuerySet from django.utils import timezone @@ -24,6 +32,7 @@ from structlog.stdlib import get_logger from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier, TaskState from authentik.tasks.results import PostgresBackend +from authentik.tenants.models import Tenant from authentik.tenants.utils import get_current_tenant LOGGER = get_logger() @@ -44,6 +53,29 @@ def raise_connection_error(func): return wrapper +class DbConnectionMiddleware(Middleware): + def _close_old_connections(self, *args, **kwargs): + close_old_connections() + + before_process_message = _close_old_connections + after_process_message = _close_old_connections + + def _close_connections(self, *args, **kwargs): + connections.close_all() + + before_consumer_thread_shutdown = _close_connections + before_worker_thread_shutdown = _close_connections + before_worker_shutdown = _close_connections + + +class TenantMiddleware(Middleware): + def before_process_message(self, broker, message): + Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate() + + def after_process_message(self, *args, **kwargs): + Tenant.deactivate() + + class PostgresBroker(Broker): def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs): super().__init__(*args, **kwargs) @@ -52,6 +84,8 @@ class PostgresBroker(Broker): self.queues = set() self.db_alias = db_alias + self.add_middleware(DbConnectionMiddleware()) + self.add_middleware(TenantMiddleware()) if results: self.backend = PostgresBackend() self.add_middleware(Results(backend=self.backend)) @@ -205,7 +239,7 @@ class _PostgresConsumer(Consumer): @property def listen_connection(self) -> DatabaseWrapper: - if self._listen_connection is not None and self._listen_connection.is_usable(): + if self._listen_connection is not None and self._listen_connection.connection is not None: return self._listen_connection self._listen_connection = connections[self.db_alias] # Required for notifications diff --git a/authentik/tasks/results.py b/authentik/tasks/results.py index 22e8a652e1..ff6c7d2bf6 100644 --- a/authentik/tasks/results.py +++ b/authentik/tasks/results.py @@ -30,19 +30,8 @@ class PostgresBackend(ResultBackend): def _store(self, message_key: str, result: Result, ttl: int) -> None: encoder = get_encoder() - query = { - "message_id": message_key, - } - defaults = { - "mtime": timezone.now(), - "result": encoder.encode(result), - "result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl), - } - # TODO: tenant - create_defaults = { - **query, - **defaults, - "queue_name": "__RQ__", - "state": TaskState.DONE, - } - self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults) + self.query_set.filter(message_id=message_key).update( + mtime=timezone.now(), + result=encoder.encode(result), + result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl), + )