@ -67,6 +67,7 @@ SHARED_APPS = [
|
|||||||
"pgactivity",
|
"pgactivity",
|
||||||
"pglock",
|
"pglock",
|
||||||
"channels",
|
"channels",
|
||||||
|
"authentik.tasks",
|
||||||
]
|
]
|
||||||
TENANT_APPS = [
|
TENANT_APPS = [
|
||||||
"django.contrib.auth",
|
"django.contrib.auth",
|
||||||
@ -122,7 +123,6 @@ TENANT_APPS = [
|
|||||||
"authentik.stages.user_login",
|
"authentik.stages.user_login",
|
||||||
"authentik.stages.user_logout",
|
"authentik.stages.user_logout",
|
||||||
"authentik.stages.user_write",
|
"authentik.stages.user_write",
|
||||||
"authentik.tasks",
|
|
||||||
"authentik.brands",
|
"authentik.brands",
|
||||||
"authentik.blueprints",
|
"authentik.blueprints",
|
||||||
"guardian",
|
"guardian",
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
import dramatiq
|
||||||
|
from dramatiq.middleware import AgeLimit, Callbacks, Prometheus, Retries, TimeLimit
|
||||||
from authentik.blueprints.apps import ManagedAppConfig
|
from authentik.blueprints.apps import ManagedAppConfig
|
||||||
|
from authentik.tasks.encoder import JSONPickleEncoder
|
||||||
|
|
||||||
|
|
||||||
class AuthentikTasksConfig(ManagedAppConfig):
|
class AuthentikTasksConfig(ManagedAppConfig):
|
||||||
@ -6,3 +10,16 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
|||||||
label = "authentik_tasks"
|
label = "authentik_tasks"
|
||||||
verbose_name = "authentik Tasks"
|
verbose_name = "authentik Tasks"
|
||||||
default = True
|
default = True
|
||||||
|
|
||||||
|
def ready(self) -> None:
|
||||||
|
from authentik.tasks.broker import PostgresBroker
|
||||||
|
|
||||||
|
dramatiq.set_encoder(JSONPickleEncoder())
|
||||||
|
broker = PostgresBroker()
|
||||||
|
broker.add_middleware(Prometheus())
|
||||||
|
broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
|
||||||
|
broker.add_middleware(TimeLimit())
|
||||||
|
broker.add_middleware(Callbacks())
|
||||||
|
broker.add_middleware(Retries(max_retries=3))
|
||||||
|
dramatiq.set_broker(broker)
|
||||||
|
return super().ready()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from enum import Enum, StrEnum, auto
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@ -23,12 +24,18 @@ from structlog.stdlib import get_logger
|
|||||||
|
|
||||||
from authentik.tasks.models import Queue as MQueue
|
from authentik.tasks.models import Queue as MQueue
|
||||||
from authentik.tasks.results import PostgresBackend
|
from authentik.tasks.results import PostgresBackend
|
||||||
|
from authentik.tenants.utils import get_current_tenant
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def channel_name(connection: DatabaseWrapper, queue_name: str, identifier: str) -> str:
|
class ChannelIdentifier(StrEnum):
|
||||||
return f"{connection.schema_name}.dramatiq.{queue_name}.{identifier}"
|
ENQUEUE = auto()
|
||||||
|
LOCK = auto()
|
||||||
|
|
||||||
|
|
||||||
|
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||||
|
return f"authentik.tasks.{queue_name}.{identifier.value}"
|
||||||
|
|
||||||
|
|
||||||
def raise_connection_error(func):
|
def raise_connection_error(func):
|
||||||
@ -117,6 +124,7 @@ class PostgresBroker(Broker):
|
|||||||
# TODO: update_or_create
|
# TODO: update_or_create
|
||||||
self.query_set.create(
|
self.query_set.create(
|
||||||
message_id=message.message_id,
|
message_id=message.message_id,
|
||||||
|
tenant=get_current_tenant(),
|
||||||
queue_name=message.queue_name,
|
queue_name=message.queue_name,
|
||||||
state=MQueue.State.QUEUED,
|
state=MQueue.State.QUEUED,
|
||||||
message=message.encode(),
|
message=message.encode(),
|
||||||
@ -196,9 +204,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
# Should be set to True by Django by default
|
# Should be set to True by Django by default
|
||||||
self._listen_connection.set_autocommit(True)
|
self._listen_connection.set_autocommit(True)
|
||||||
with self._listen_connection.cursor() as cursor:
|
with self._listen_connection.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
|
||||||
"LISTEN %s", channel_name(self._listen_connection, self.queue_name, "enqueue")
|
|
||||||
)
|
|
||||||
|
|
||||||
@raise_connection_error
|
@raise_connection_error
|
||||||
def ack(self, message: Message):
|
def ack(self, message: Message):
|
||||||
@ -240,7 +246,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
notifies = self.query_set.filter(
|
notifies = self.query_set.filter(
|
||||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
||||||
)
|
)
|
||||||
channel = channel_name(self.connection, self.queue_name, "enqueue")
|
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
||||||
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
||||||
|
|
||||||
def _poll_for_notify(self):
|
def _poll_for_notify(self):
|
||||||
@ -251,7 +257,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
|
|
||||||
def _get_message_lock_id(self, message: Message) -> int:
|
def _get_message_lock_id(self, message: Message) -> int:
|
||||||
return _cast_lock_id(
|
return _cast_lock_id(
|
||||||
f"{channel_name(connections[self.connection], self.queue_name, 'lock')}.{message.message_id}" # noqa: E501
|
f"{channel_name(self.queue_name, ChannelIdentifier.LOCK)}.{message.message_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _consume_one(self, message: Message) -> bool:
|
def _consume_one(self, message: Message) -> bool:
|
||||||
|
32
authentik/tasks/encoder.py
Normal file
32
authentik/tasks/encoder.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import jsonpickle
|
||||||
|
import dramatiq.encoder
|
||||||
|
from typing import Any
|
||||||
|
from dramatiq.encoder import MessageData
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
|
||||||
|
class OrjsonBackend(jsonpickle.JSONBackend):
|
||||||
|
def encode(self, obj: Any, indent=None, separators=None) -> str:
|
||||||
|
return orjson.dumps(obj, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
|
||||||
|
|
||||||
|
def decode(self, string: str) -> Any:
|
||||||
|
return orjson.loads(string)
|
||||||
|
|
||||||
|
|
||||||
|
class JSONPickleEncoder(dramatiq.encoder.Encoder):
|
||||||
|
def encode(self, data: MessageData) -> bytes:
|
||||||
|
return jsonpickle.encode(
|
||||||
|
data,
|
||||||
|
backend=OrjsonBackend(),
|
||||||
|
keys=True,
|
||||||
|
warn=True,
|
||||||
|
use_base85=True,
|
||||||
|
).encode("utf-8")
|
||||||
|
|
||||||
|
def decode(self, data: bytes) -> MessageData:
|
||||||
|
return jsonpickle.decode(
|
||||||
|
data.decode("utf-8"),
|
||||||
|
backend=OrjsonBackend(),
|
||||||
|
keys=True,
|
||||||
|
on_missing="warn",
|
||||||
|
)
|
@ -3,6 +3,8 @@ from uuid import uuid4
|
|||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from authentik.tenants.models import Tenant
|
||||||
|
|
||||||
|
|
||||||
class Queue(models.Model):
|
class Queue(models.Model):
|
||||||
class State(models.TextChoices):
|
class State(models.TextChoices):
|
||||||
@ -11,6 +13,7 @@ class Queue(models.Model):
|
|||||||
REJECTED = "rejected"
|
REJECTED = "rejected"
|
||||||
DONE = "done"
|
DONE = "done"
|
||||||
|
|
||||||
|
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
||||||
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||||
queue_name = models.TextField(default="default")
|
queue_name = models.TextField(default="default")
|
||||||
state = models.CharField(default=State.QUEUED, choices=State.choices)
|
state = models.CharField(default=State.QUEUED, choices=State.choices)
|
||||||
|
Reference in New Issue
Block a user