Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-11 17:25:51 +01:00
parent 3ddc35cddc
commit 677f04cab2
2 changed files with 41 additions and 18 deletions

View File

@ -1,3 +1,4 @@
from dramatiq.middleware import Middleware
from psycopg import sql
import functools
import logging
@ -8,7 +9,14 @@ from random import randint
import orjson
import tenacity
from django.db import DEFAULT_DB_ALIAS, DatabaseError, InterfaceError, OperationalError, connections
from django.db import (
DEFAULT_DB_ALIAS,
DatabaseError,
InterfaceError,
OperationalError,
close_old_connections,
connections,
)
from django.db.backends.postgresql.base import DatabaseWrapper
from django.db.models import QuerySet
from django.utils import timezone
@ -24,6 +32,7 @@ from structlog.stdlib import get_logger
from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier, TaskState
from authentik.tasks.results import PostgresBackend
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
@ -44,6 +53,29 @@ def raise_connection_error(func):
return wrapper
class DbConnectionMiddleware(Middleware):
def _close_old_connections(self, *args, **kwargs):
close_old_connections()
before_process_message = _close_old_connections
after_process_message = _close_old_connections
def _close_connections(self, *args, **kwargs):
connections.close_all()
before_consumer_thread_shutdown = _close_connections
before_worker_thread_shutdown = _close_connections
before_worker_shutdown = _close_connections
class TenantMiddleware(Middleware):
def before_process_message(self, broker, message):
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
def after_process_message(self, *args, **kwargs):
Tenant.deactivate()
class PostgresBroker(Broker):
def __init__(self, *args, db_alias: str = DEFAULT_DB_ALIAS, results: bool = True, **kwargs):
super().__init__(*args, **kwargs)
@ -52,6 +84,8 @@ class PostgresBroker(Broker):
self.queues = set()
self.db_alias = db_alias
self.add_middleware(DbConnectionMiddleware())
self.add_middleware(TenantMiddleware())
if results:
self.backend = PostgresBackend()
self.add_middleware(Results(backend=self.backend))
@ -205,7 +239,7 @@ class _PostgresConsumer(Consumer):
@property
def listen_connection(self) -> DatabaseWrapper:
if self._listen_connection is not None and self._listen_connection.is_usable():
if self._listen_connection is not None and self._listen_connection.connection is not None:
return self._listen_connection
self._listen_connection = connections[self.db_alias]
# Required for notifications

View File

@ -30,19 +30,8 @@ class PostgresBackend(ResultBackend):
def _store(self, message_key: str, result: Result, ttl: int) -> None:
encoder = get_encoder()
query = {
"message_id": message_key,
}
defaults = {
"mtime": timezone.now(),
"result": encoder.encode(result),
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
}
# TODO: tenant
create_defaults = {
**query,
**defaults,
"queue_name": "__RQ__",
"state": TaskState.DONE,
}
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)
self.query_set.filter(message_id=message_key).update(
mtime=timezone.now(),
result=encoder.encode(result),
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl),
)