@ -1,3 +1,4 @@
|
||||
from dramatiq.middleware import Middleware
|
||||
from psycopg import sql
|
||||
import functools
|
||||
import logging
|
||||
@ -8,7 +9,14 @@ from random import randint
|
||||
|
||||
import orjson
|
||||
import tenacity
|
||||
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
DatabaseError,
|
||||
InterfaceError,
|
||||
OperationalError,
|
||||
close_old_connections,
|
||||
connections,
|
||||
)
|
||||
from django.db.backends.postgresql.base import DatabaseWrapper
|
||||
from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
@ -24,6 +32,7 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier, TaskState
|
||||
from authentik.tasks.results import PostgresBackend
|
||||
from authentik.tenants.models import Tenant
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
@ -44,6 +53,29 @@ def raise_connection_error(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class DbConnectionMiddleware(Middleware):
|
||||
def _close_old_connections(self, *args, **kwargs):
|
||||
close_old_connections()
|
||||
|
||||
before_process_message = _close_old_connections
|
||||
after_process_message = _close_old_connections
|
||||
|
||||
def _close_connections(self, *args, **kwargs):
|
||||
connections.close_all()
|
||||
|
||||
before_consumer_thread_shutdown = _close_connections
|
||||
before_worker_thread_shutdown = _close_connections
|
||||
before_worker_shutdown = _close_connections
|
||||
|
||||
|
||||
class TenantMiddleware(Middleware):
|
||||
def before_process_message(self, broker, message):
|
||||
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
|
||||
|
||||
def after_process_message(self, *args, **kwargs):
|
||||
Tenant.deactivate()
|
||||
|
||||
|
||||
class PostgresBroker(Broker):
|
||||
def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -52,6 +84,8 @@ class PostgresBroker(Broker):
|
||||
self.queues = set()
|
||||
|
||||
self.db_alias = db_alias
|
||||
self.add_middleware(DbConnectionMiddleware())
|
||||
self.add_middleware(TenantMiddleware())
|
||||
if results:
|
||||
self.backend = PostgresBackend()
|
||||
self.add_middleware(Results(backend=self.backend))
|
||||
@ -205,7 +239,7 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
@property
|
||||
def listen_connection(self) -> DatabaseWrapper:
|
||||
if self._listen_connection is not None and self._listen_connection.is_usable():
|
||||
if self._listen_connection is not None and self._listen_connection.connection is not None:
|
||||
return self._listen_connection
|
||||
self._listen_connection = connections[self.db_alias]
|
||||
# Required for notifications
|
||||
|
@ -30,19 +30,8 @@ class PostgresBackend(ResultBackend):
|
||||
|
||||
def _store(self, message_key: str, result: Result, ttl: int) -> None:
|
||||
encoder = get_encoder()
|
||||
query = {
|
||||
"message_id": message_key,
|
||||
}
|
||||
defaults = {
|
||||
"mtime": timezone.now(),
|
||||
"result": encoder.encode(result),
|
||||
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
|
||||
}
|
||||
# TODO: tenant
|
||||
create_defaults = {
|
||||
**query,
|
||||
**defaults,
|
||||
"queue_name": "__RQ__",
|
||||
"state": TaskState.DONE,
|
||||
}
|
||||
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)
|
||||
self.query_set.filter(message_id=message_key).update(
|
||||
mtime=timezone.now(),
|
||||
result=encoder.encode(result),
|
||||
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl),
|
||||
)
|
||||
|
Reference in New Issue
Block a user