@ -13,6 +13,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
||||
|
||||
def ready(self) -> None:
|
||||
from authentik.tasks.broker import PostgresBroker
|
||||
from authentik.tasks.middleware import CurrentTask
|
||||
|
||||
dramatiq.set_encoder(JSONPickleEncoder())
|
||||
broker = PostgresBroker()
|
||||
@ -21,5 +22,6 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
||||
broker.add_middleware(TimeLimit())
|
||||
broker.add_middleware(Callbacks())
|
||||
broker.add_middleware(Retries(max_retries=3))
|
||||
broker.add_middleware(CurrentTask())
|
||||
dramatiq.set_broker(broker)
|
||||
return super().ready()
|
||||
|
||||
@ -21,7 +21,7 @@ from psycopg import Notify
|
||||
from psycopg.errors import AdminShutdown
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
|
||||
from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier
|
||||
from authentik.tasks.results import PostgresBackend
|
||||
from authentik.tenants.utils import get_current_tenant
|
||||
|
||||
@ -65,7 +65,7 @@ class PostgresBroker(Broker):
|
||||
|
||||
@property
|
||||
def query_set(self) -> QuerySet:
|
||||
return MQueue.objects.using(self.db_alias)
|
||||
return Task.objects.using(self.db_alias)
|
||||
|
||||
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
|
||||
self.declare_queue(queue_name)
|
||||
@ -125,7 +125,7 @@ class PostgresBroker(Broker):
|
||||
defaults = {
|
||||
"tenant": get_current_tenant(),
|
||||
"queue_name": message.queue_name,
|
||||
"state": MQueue.State.QUEUED,
|
||||
"state": Task.State.QUEUED,
|
||||
"message": encoded,
|
||||
}
|
||||
create_defaults = {
|
||||
@ -169,7 +169,7 @@ class PostgresBroker(Broker):
|
||||
if (
|
||||
self.query_set.filter(
|
||||
queue_name=queue_name,
|
||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
||||
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
|
||||
)
|
||||
== 0
|
||||
):
|
||||
@ -200,7 +200,7 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
@property
|
||||
def query_set(self) -> QuerySet:
|
||||
return MQueue.objects.using(self.db_alias)
|
||||
return Task.objects.using(self.db_alias)
|
||||
|
||||
@property
|
||||
def listen_connection(self) -> DatabaseWrapper:
|
||||
@ -220,9 +220,9 @@ class _PostgresConsumer(Consumer):
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
state=MQueue.State.CONSUMED,
|
||||
state=Task.State.CONSUMED,
|
||||
).update(
|
||||
state=MQueue.State.DONE,
|
||||
state=Task.State.DONE,
|
||||
message=message.encode(),
|
||||
)
|
||||
self.in_processing.remove(message.message_id)
|
||||
@ -233,9 +233,9 @@ class _PostgresConsumer(Consumer):
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
state__ne=MQueue.State.REJECTED,
|
||||
state__ne=Task.State.REJECTED,
|
||||
).update(
|
||||
state=MQueue.State.REJECTED,
|
||||
state=Task.State.REJECTED,
|
||||
message=message.encode(),
|
||||
)
|
||||
self.in_processing.remove(message.message_id)
|
||||
@ -245,14 +245,14 @@ class _PostgresConsumer(Consumer):
|
||||
self.query_set.filter(
|
||||
message_id__in=[message.message_id for message in messages],
|
||||
).update(
|
||||
state=MQueue.State.QUEUED,
|
||||
state=Task.State.QUEUED,
|
||||
)
|
||||
# We don't care about locks, requeue occurs on worker stop
|
||||
|
||||
def _fetch_pending_notifies(self) -> list[Notify]:
|
||||
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
|
||||
notifies = self.query_set.filter(
|
||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
||||
state__in=(Task.State.QUEUED, Task.State.CONSUMED), queue_name=self.queue_name
|
||||
)
|
||||
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
||||
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
||||
@ -276,9 +276,9 @@ class _PostgresConsumer(Consumer):
|
||||
result = (
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
||||
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
|
||||
)
|
||||
.update(state=MQueue.State.CONSUMED, mtime=timezone.now())
|
||||
.update(state=Task.State.CONSUMED, mtime=timezone.now())
|
||||
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
|
||||
)
|
||||
return result == 1
|
||||
@ -354,7 +354,7 @@ class _PostgresConsumer(Consumer):
|
||||
return
|
||||
self.logger.debug("Running garbage collector")
|
||||
count = self.query_set.filter(
|
||||
state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
|
||||
state__in=(Task.State.DONE, Task.State.REJECTED),
|
||||
mtime__lte=timezone.now() - timezone.timedelta(days=30),
|
||||
).delete()
|
||||
self.logger.info(f"Purged {count} messages in all queues")
|
||||
|
||||
19
authentik/tasks/middleware.py
Normal file
19
authentik/tasks/middleware.py
Normal file
@ -0,0 +1,19 @@
|
||||
import contextvars
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware import Middleware
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class CurrentTask(Middleware):
|
||||
_TASK: contextvars.ContextVar[Task | None] = contextvars.ContextVar("_TASK", default=None)
|
||||
|
||||
@classmethod
|
||||
def get_task(cls) -> Task | None:
|
||||
return cls._TASK.get()
|
||||
|
||||
def before_process_message(self, _, message: Message):
|
||||
self._TASK.set(Task.objects.get(message_id=message.message_id))
|
||||
|
||||
def after_process_message(self, *args, **kwargs):
|
||||
self._TASK.get().save()
|
||||
self._TASK.set(None)
|
||||
@ -1,79 +0,0 @@
|
||||
# Generated by Django 5.0.12 on 2025-03-09 00:47
|
||||
|
||||
import django.db.models.deletion
|
||||
import django.utils.timezone
|
||||
import pgtrigger.compiler
|
||||
import pgtrigger.migrations
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Queue",
|
||||
fields=[
|
||||
(
|
||||
"message_id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("queue_name", models.TextField(default="default")),
|
||||
(
|
||||
"state",
|
||||
models.CharField(
|
||||
choices=[
|
||||
("queued", "Queued"),
|
||||
("consumed", "Consumed"),
|
||||
("rejected", "Rejected"),
|
||||
("done", "Done"),
|
||||
],
|
||||
default="queued",
|
||||
),
|
||||
),
|
||||
("mtime", models.DateTimeField(default=django.utils.timezone.now)),
|
||||
("message", models.JSONField(blank=True, null=True)),
|
||||
("result", models.JSONField(blank=True, null=True)),
|
||||
("result_ttl", models.DateTimeField(blank=True, null=True)),
|
||||
(
|
||||
"tenant",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"indexes": [
|
||||
models.Index(fields=["state", "mtime"], name="authentik_t_state_b7ff76_idx")
|
||||
],
|
||||
},
|
||||
),
|
||||
migrations.RunSQL(
|
||||
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
|
||||
),
|
||||
pgtrigger.migrations.AddTrigger(
|
||||
model_name="queue",
|
||||
trigger=pgtrigger.compiler.Trigger(
|
||||
name="notify_enqueueing",
|
||||
sql=pgtrigger.compiler.UpsertTriggerSql(
|
||||
condition="WHEN (NEW.\"state\" = 'queued')",
|
||||
constraint="CONSTRAINT",
|
||||
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
|
||||
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
|
||||
operation="INSERT OR UPDATE",
|
||||
pgid="pgtrigger_notify_enqueueing_b1977",
|
||||
table="authentik_tasks_queue",
|
||||
timing="DEFERRABLE INITIALLY DEFERRED",
|
||||
when="AFTER",
|
||||
),
|
||||
),
|
||||
),
|
||||
]
|
||||
@ -5,6 +5,7 @@ import pgtrigger
|
||||
from django.db import models
|
||||
from django.utils import timezone
|
||||
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.tenants.models import Tenant
|
||||
|
||||
CHANNEL_PREFIX = "authentik.tasks"
|
||||
@ -15,21 +16,23 @@ class ChannelIdentifier(StrEnum):
|
||||
LOCK = auto()
|
||||
|
||||
|
||||
class Queue(models.Model):
|
||||
class Task(SerializerModel):
|
||||
class State(models.TextChoices):
|
||||
QUEUED = "queued"
|
||||
CONSUMED = "consumed"
|
||||
REJECTED = "rejected"
|
||||
DONE = "done"
|
||||
|
||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
||||
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||
queue_name = models.TextField(default="default")
|
||||
state = models.CharField(default=State.QUEUED, choices=State.choices)
|
||||
mtime = models.DateTimeField(default=timezone.now)
|
||||
message = models.JSONField(blank=True, null=True)
|
||||
result = models.JSONField(blank=True, null=True)
|
||||
result_ttl = models.DateTimeField(blank=True, null=True)
|
||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False)
|
||||
queue_name = models.TextField(default="default", editable=False)
|
||||
state = models.CharField(default=State.QUEUED, choices=State.choices, editable=False)
|
||||
mtime = models.DateTimeField(default=timezone.now, editable=False)
|
||||
message = models.JSONField(blank=True, null=True, editable=False)
|
||||
result = models.JSONField(blank=True, null=True, editable=False)
|
||||
result_ttl = models.DateTimeField(blank=True, null=True, editable=False)
|
||||
description = models.TextField(blank=True)
|
||||
messages = models.JSONField(blank=True, null=True, editable=False)
|
||||
|
||||
class Meta:
|
||||
indexes = (models.Index(fields=("state", "mtime")),)
|
||||
@ -55,3 +58,8 @@ class Queue(models.Model):
|
||||
|
||||
def __str__(self):
|
||||
return str(self.message_id)
|
||||
|
||||
@property
|
||||
def serializer(self):
|
||||
# TODO: fixme
|
||||
pass
|
||||
|
||||
@ -4,7 +4,7 @@ from django.utils import timezone
|
||||
from dramatiq.message import Message, get_encoder
|
||||
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
|
||||
|
||||
from authentik.tasks.models import Queue
|
||||
from authentik.tasks.models import Task
|
||||
|
||||
|
||||
class PostgresBackend(ResultBackend):
|
||||
@ -14,7 +14,7 @@ class PostgresBackend(ResultBackend):
|
||||
|
||||
@property
|
||||
def query_set(self) -> QuerySet:
|
||||
return Queue.objects.using(self.db_alias)
|
||||
return Task.objects.using(self.db_alias)
|
||||
|
||||
def build_message_key(self, message: Message) -> str:
|
||||
return str(message.message_id)
|
||||
|
||||
Reference in New Issue
Block a user