Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-09 01:10:58 +01:00
parent de54404ab7
commit e8cfc2b91e
5 changed files with 66 additions and 8 deletions

View File

@ -67,6 +67,7 @@ SHARED_APPS = [
"pgactivity",
"pglock",
"channels",
"authentik.tasks",
]
TENANT_APPS = [
"django.contrib.auth",
@ -122,7 +123,6 @@ TENANT_APPS = [
"authentik.stages.user_login",
"authentik.stages.user_logout",
"authentik.stages.user_write",
"authentik.tasks",
"authentik.brands",
"authentik.blueprints",
"guardian",

View File

@ -1,4 +1,8 @@
from datetime import timedelta
import dramatiq
from dramatiq.middleware import AgeLimit, Callbacks, Prometheus, Retries, TimeLimit
from authentik.blueprints.apps import ManagedAppConfig
from authentik.tasks.encoder import JSONPickleEncoder
class AuthentikTasksConfig(ManagedAppConfig):
@ -6,3 +10,16 @@ class AuthentikTasksConfig(ManagedAppConfig):
label = "authentik_tasks"
verbose_name = "authentik Tasks"
default = True
def ready(self) -> None:
from authentik.tasks.broker import PostgresBroker
dramatiq.set_encoder(JSONPickleEncoder())
broker = PostgresBroker()
broker.add_middleware(Prometheus())
broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
broker.add_middleware(TimeLimit())
broker.add_middleware(Callbacks())
broker.add_middleware(Retries(max_retries=3))
dramatiq.set_broker(broker)
return super().ready()

View File

@ -1,3 +1,4 @@
from enum import Enum, StrEnum, auto
import functools
import logging
import time
@ -23,12 +24,18 @@ from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue
from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
def channel_name(connection: DatabaseWrapper, queue_name: str, identifier: str) -> str:
return f"{connection.schema_name}.dramatiq.{queue_name}.{identifier}"
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"authentik.tasks.{queue_name}.{identifier.value}"
def raise_connection_error(func):
@ -117,6 +124,7 @@ class PostgresBroker(Broker):
# TODO: update_or_create
self.query_set.create(
message_id=message.message_id,
tenant=get_current_tenant(),
queue_name=message.queue_name,
state=MQueue.State.QUEUED,
message=message.encode(),
@ -196,9 +204,7 @@ class _PostgresConsumer(Consumer):
# Should be set to True by Django by default
self._listen_connection.set_autocommit(True)
with self._listen_connection.cursor() as cursor:
cursor.execute(
"LISTEN %s", channel_name(self._listen_connection, self.queue_name, "enqueue")
)
cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
@raise_connection_error
def ack(self, message: Message):
@ -240,7 +246,7 @@ class _PostgresConsumer(Consumer):
notifies = self.query_set.filter(
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
)
channel = channel_name(self.connection, self.queue_name, "enqueue")
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
def _poll_for_notify(self):
@ -251,7 +257,7 @@ class _PostgresConsumer(Consumer):
def _get_message_lock_id(self, message: Message) -> int:
return _cast_lock_id(
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
)
def _consume_one(self, message: Message) -> bool:

View File

@ -0,0 +1,32 @@
import jsonpickle
import dramatiq.encoder
from typing import Any
from dramatiq.encoder import MessageData
import orjson
class OrjsonBackend(jsonpickle.JSONBackend):
def encode(self, obj: Any, indent=None, separators=None) -> str:
return orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
def decode(self, string: str) -> Any:
return orjson.loads(string)
class JSONPickleEncoder(dramatiq.encoder.Encoder):
def encode(self, data: MessageData) -> bytes:
return jsonpickle.encode(
data,
backend=OrjsonBackend(),
keys=True,
warn=True,
use_base85=True,
).encode("utf-8")
def decode(self, data: bytes) -> MessageData:
return jsonpickle.decode(
data.decode("utf-8"),
backend=OrjsonBackend(),
keys=True,
on_missing="warn",
)

View File

@ -3,6 +3,8 @@ from uuid import uuid4
from django.db import models
from django.utils import timezone
from authentik.tenants.models import Tenant
class Queue(models.Model):
class State(models.TextChoices):
@ -11,6 +13,7 @@ class Queue(models.Model):
REJECTED = "rejected"
DONE = "done"
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
queue_name = models.TextField(default="default")
state = models.CharField(default=State.QUEUED, choices=State.choices)