@ -67,6 +67,7 @@ SHARED_APPS = [
|
||||
"pgactivity",
|
||||
"pglock",
|
||||
"channels",
|
||||
"authentik.tasks",
|
||||
]
|
||||
TENANT_APPS = [
|
||||
"django.contrib.auth",
|
||||
@ -122,7 +123,6 @@ TENANT_APPS = [
|
||||
"authentik.stages.user_login",
|
||||
"authentik.stages.user_logout",
|
||||
"authentik.stages.user_write",
|
||||
"authentik.tasks",
|
||||
"authentik.brands",
|
||||
"authentik.blueprints",
|
||||
"guardian",
|
||||
|
@ -1,4 +1,8 @@
|
||||
from datetime import timedelta
|
||||
import dramatiq
|
||||
from dramatiq.middleware import AgeLimit, Callbacks, Prometheus, Retries, TimeLimit
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
from authentik.tasks.encoder import JSONPickleEncoder
|
||||
|
||||
|
||||
class AuthentikTasksConfig(ManagedAppConfig):
|
||||
@ -6,3 +10,16 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
||||
label = "authentik_tasks"
|
||||
verbose_name = "authentik Tasks"
|
||||
default = True
|
||||
|
||||
def ready(self) -> None:
|
||||
from authentik.tasks.broker import PostgresBroker
|
||||
|
||||
dramatiq.set_encoder(JSONPickleEncoder())
|
||||
broker = PostgresBroker()
|
||||
broker.add_middleware(Prometheus())
|
||||
broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
|
||||
broker.add_middleware(TimeLimit())
|
||||
broker.add_middleware(Callbacks())
|
||||
broker.add_middleware(Retries(max_retries=3))
|
||||
dramatiq.set_broker(broker)
|
||||
return super().ready()
|
||||
|
@ -1,3 +1,4 @@
|
||||
from enum import Enum, StrEnum, auto
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
@ -23,12 +24,18 @@ from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.tasks.models import Queue as MQueue
|
||||
from authentik.tasks.results import PostgresBackend
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def channel_name(connection: DatabaseWrapper, queue_name: str, identifier: str) -> str:
|
||||
return f"{connection.schema_name}.dramatiq.{queue_name}.{identifier}"
|
||||
class ChannelIdentifier(StrEnum):
|
||||
ENQUEUE = auto()
|
||||
LOCK = auto()
|
||||
|
||||
|
||||
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||
return f"authentik.tasks.{queue_name}.{identifier.value}"
|
||||
|
||||
|
||||
def raise_connection_error(func):
|
||||
@ -117,6 +124,7 @@ class PostgresBroker(Broker):
|
||||
# TODO: update_or_create
|
||||
self.query_set.create(
|
||||
message_id=message.message_id,
|
||||
tenant=get_current_tenant(),
|
||||
queue_name=message.queue_name,
|
||||
state=MQueue.State.QUEUED,
|
||||
message=message.encode(),
|
||||
@ -196,9 +204,7 @@ class _PostgresConsumer(Consumer):
|
||||
# Should be set to True by Django by default
|
||||
self._listen_connection.set_autocommit(True)
|
||||
with self._listen_connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"LISTEN %s", channel_name(self._listen_connection, self.queue_name, "enqueue")
|
||||
)
|
||||
cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
|
||||
|
||||
@raise_connection_error
|
||||
def ack(self, message: Message):
|
||||
@ -240,7 +246,7 @@ class _PostgresConsumer(Consumer):
|
||||
notifies = self.query_set.filter(
|
||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
||||
)
|
||||
channel = channel_name(self.connection, self.queue_name, "enqueue")
|
||||
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
||||
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
||||
|
||||
def _poll_for_notify(self):
|
||||
@ -251,7 +257,7 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
def _get_message_lock_id(self, message: Message) -> int:
|
||||
return _cast_lock_id(
|
||||
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501
|
||||
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
|
||||
)
|
||||
|
||||
def _consume_one(self, message: Message) -> bool:
|
||||
|
32
authentik/tasks/encoder.py
Normal file
32
authentik/tasks/encoder.py
Normal file
@ -0,0 +1,32 @@
|
||||
import jsonpickle
|
||||
import dramatiq.encoder
|
||||
from typing import Any
|
||||
from dramatiq.encoder import MessageData
|
||||
import orjson
|
||||
|
||||
|
||||
class OrjsonBackend(jsonpickle.JSONBackend):
|
||||
def encode(self, obj: Any, indent=None, separators=None) -> str:
|
||||
return orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
|
||||
|
||||
def decode(self, string: str) -> Any:
|
||||
return orjson.loads(string)
|
||||
|
||||
|
||||
class JSONPickleEncoder(dramatiq.encoder.Encoder):
|
||||
def encode(self, data: MessageData) -> bytes:
|
||||
return jsonpickle.encode(
|
||||
data,
|
||||
backend=OrjsonBackend(),
|
||||
keys=True,
|
||||
warn=True,
|
||||
use_base85=True,
|
||||
).encode("utf-8")
|
||||
|
||||
def decode(self, data: bytes) -> MessageData:
|
||||
return jsonpickle.decode(
|
||||
data.decode("utf-8"),
|
||||
backend=OrjsonBackend(),
|
||||
keys=True,
|
||||
on_missing="warn",
|
||||
)
|
@ -3,6 +3,8 @@ from uuid import uuid4
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
|
||||
class Queue(models.Model):
|
||||
class State(models.TextChoices):
|
||||
@ -11,6 +13,7 @@ class Queue(models.Model):
|
||||
REJECTED = "rejected"
|
||||
DONE = "done"
|
||||
|
||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
||||
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
queue_name = models.TextField(default="default")
|
||||
state = models.CharField(default=State.QUEUED, choices=State.choices)
|
||||
|
Reference in New Issue
Block a user