Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Marc 'risson' Schmitt
2025-03-09 02:30:38 +01:00
parent 2b1ee8cd5c
commit c4b988c632
6 changed files with 53 additions and 103 deletions

View File

@@ -13,6 +13,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
def ready(self) -> None:
from authentik.tasks.broker import PostgresBroker
from authentik.tasks.middleware import CurrentTask
dramatiq.set_encoder(JSONPickleEncoder())
broker = PostgresBroker()
@@ -21,5 +22,6 @@ class AuthentikTasksConfig(ManagedAppConfig):
broker.add_middleware(TimeLimit())
broker.add_middleware(Callbacks())
broker.add_middleware(Retries(max_retries=3))
broker.add_middleware(CurrentTask())
dramatiq.set_broker(broker)
return super().ready()
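
For context, once `ready()` has installed the `PostgresBroker` and its middleware stack, actors are declared and enqueued through the regular Dramatiq API. A minimal sketch, assuming the broker above is already set; the actor name and arguments are illustrative, not part of this commit:

import dramatiq

# Assumes dramatiq.set_broker() has run via AuthentikTasksConfig.ready().
@dramatiq.actor(queue_name="default", max_retries=3)
def apply_blueprint(blueprint_pk: str):
    ...

# Enqueueing stores the message as a Task row in the QUEUED state,
# which the PostgresBroker consumer then picks up.
apply_blueprint.send("some-uuid")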

View File

@@ -21,7 +21,7 @@ from psycopg import Notify
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant
@@ -65,7 +65,7 @@ class PostgresBroker(Broker):
@property
def query_set(self) -> QuerySet:
return MQueue.objects.using(self.db_alias)
return Task.objects.using(self.db_alias)
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
self.declare_queue(queue_name)
@@ -125,7 +125,7 @@ class PostgresBroker(Broker):
defaults = {
"tenant": get_current_tenant(),
"queue_name": message.queue_name,
"state": MQueue.State.QUEUED,
"state": Task.State.QUEUED,
"message": encoded,
}
create_defaults = {
@@ -169,7 +169,7 @@ class PostgresBroker(Broker):
if (
self.query_set.filter(
queue_name=queue_name,
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
)
== 0
):
@@ -200,7 +200,7 @@ class _PostgresConsumer(Consumer):
@property
def query_set(self) -> QuerySet:
return MQueue.objects.using(self.db_alias)
return Task.objects.using(self.db_alias)
@property
def listen_connection(self) -> DatabaseWrapper:
@@ -220,9 +220,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state=MQueue.State.CONSUMED,
state=Task.State.CONSUMED,
).update(
state=MQueue.State.DONE,
state=Task.State.DONE,
message=message.encode(),
)
self.in_processing.remove(message.message_id)
@@ -233,9 +233,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state__ne=MQueue.State.REJECTED,
state__ne=Task.State.REJECTED,
).update(
state=MQueue.State.REJECTED,
state=Task.State.REJECTED,
message=message.encode(),
)
self.in_processing.remove(message.message_id)
@@ -245,14 +245,14 @@ class _PostgresConsumer(Consumer):
self.query_set.filter(
message_id__in=[message.message_id for message in messages],
).update(
state=MQueue.State.QUEUED,
state=Task.State.QUEUED,
)
# We don't care about locks; requeueing happens when the worker stops
def _fetch_pending_notifies(self) -> list[Notify]:
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
notifies = self.query_set.filter(
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
state__in=(Task.State.QUEUED, Task.State.CONSUMED), queue_name=self.queue_name
)
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
@@ -276,9 +276,9 @@ class _PostgresConsumer(Consumer):
result = (
self.query_set.filter(
message_id=message.message_id,
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
)
.update(state=MQueue.State.CONSUMED, mtime=timezone.now())
.update(state=Task.State.CONSUMED, mtime=timezone.now())
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
)
return result == 1
@@ -354,7 +354,7 @@ class _PostgresConsumer(Consumer):
return
self.logger.debug("Running garbage collector")
count = self.query_set.filter(
state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
state__in=(Task.State.DONE, Task.State.REJECTED),
mtime__lte=timezone.now() - timezone.timedelta(days=30),
).delete()
self.logger.info(f"Purged {count} messages in all queues")

View File

@@ -0,0 +1,19 @@
import contextvars
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from authentik.tasks.models import Task
class CurrentTask(Middleware):
    _TASK: contextvars.ContextVar[Task | None] = contextvars.ContextVar("_TASK", default=None)

    @classmethod
    def get_task(cls) -> Task | None:
        return cls._TASK.get()

    def before_process_message(self, _, message: Message):
        self._TASK.set(Task.objects.get(message_id=message.message_id))

    def after_process_message(self, *args, **kwargs):
        # Guard against a missing task, e.g. if before_process_message did not run
        task = self._TASK.get()
        if task is not None:
            task.save()
        self._TASK.set(None)
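
A likely usage pattern for the new middleware: inside an actor, `CurrentTask.get_task()` returns the `Task` row backing the message currently being processed, so the actor can record progress on it, and `after_process_message` persists whatever was set. A minimal sketch; the actor below is illustrative, not part of this commit:

import dramatiq

from authentik.tasks.middleware import CurrentTask

@dramatiq.actor
def example_actor():
    # `description` comes from the Task model in this commit;
    # the actor itself is hypothetical.
    task = CurrentTask.get_task()
    if task is not None:
        task.description = "Processing example workload"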

View File

@@ -1,79 +0,0 @@
# Generated by Django 5.0.12 on 2025-03-09 00:47
import django.db.models.deletion
import django.utils.timezone
import pgtrigger.compiler
import pgtrigger.migrations
import uuid
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
]
operations = [
migrations.CreateModel(
name="Queue",
fields=[
(
"message_id",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("queue_name", models.TextField(default="default")),
(
"state",
models.CharField(
choices=[
("queued", "Queued"),
("consumed", "Consumed"),
("rejected", "Rejected"),
("done", "Done"),
],
default="queued",
),
),
("mtime", models.DateTimeField(default=django.utils.timezone.now)),
("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField(blank=True, null=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
),
),
],
options={
"indexes": [
models.Index(fields=["state", "mtime"], name="authentik_t_state_b7ff76_idx")
],
},
),
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="queue",
trigger=pgtrigger.compiler.Trigger(
name="notify_enqueueing",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_b1977",
table="authentik_tasks_queue",
timing="DEFERRABLE INITIALLY DEFERRED",
when="AFTER",
),
),
),
]

View File

@@ -5,6 +5,7 @@ import pgtrigger
from django.db import models
from django.utils import timezone
from authentik.lib.models import SerializerModel
from authentik.tenants.models import Tenant
CHANNEL_PREFIX = "authentik.tasks"
@@ -15,21 +16,23 @@ class ChannelIdentifier(StrEnum):
LOCK = auto()
class Queue(models.Model):
class Task(SerializerModel):
class State(models.TextChoices):
QUEUED = "queued"
CONSUMED = "consumed"
REJECTED = "rejected"
DONE = "done"
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
queue_name = models.TextField(default="default")
state = models.CharField(default=State.QUEUED, choices=State.choices)
mtime = models.DateTimeField(default=timezone.now)
message = models.JSONField(blank=True, null=True)
result = models.JSONField(blank=True, null=True)
result_ttl = models.DateTimeField(blank=True, null=True)
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False)
queue_name = models.TextField(default="default", editable=False)
state = models.CharField(default=State.QUEUED, choices=State.choices, editable=False)
mtime = models.DateTimeField(default=timezone.now, editable=False)
message = models.JSONField(blank=True, null=True, editable=False)
result = models.JSONField(blank=True, null=True, editable=False)
result_ttl = models.DateTimeField(blank=True, null=True, editable=False)
description = models.TextField(blank=True)
messages = models.JSONField(blank=True, null=True, editable=False)
class Meta:
indexes = (models.Index(fields=("state", "mtime")),)
@@ -55,3 +58,8 @@ class Queue(models.Model):
def __str__(self):
return str(self.message_id)
@property
def serializer(self):
# TODO: fixme
pass
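
The `serializer` property is left as a TODO in this commit. `SerializerModel` subclasses elsewhere in authentik return a REST framework serializer class from this property, so a follow-up would presumably look something like the sketch below; the `TaskSerializer` import path is hypothetical and not part of this commit:

    @property
    def serializer(self):
        # Hypothetical: assumes a TaskSerializer is added in a later change
        from authentik.tasks.api import TaskSerializer

        return TaskSerializer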

View File

@@ -4,7 +4,7 @@ from django.utils import timezone
from dramatiq.message import Message, get_encoder
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
from authentik.tasks.models import Queue
from authentik.tasks.models import Task
class PostgresBackend(ResultBackend):
@@ -14,7 +14,7 @@ class PostgresBackend(ResultBackend):
@property
def query_set(self) -> QuerySet:
return Queue.objects.using(self.db_alias)
return Task.objects.using(self.db_alias)
def build_message_key(self, message: Message) -> str:
return str(message.message_id)
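
For completeness, results stored by this backend are read back through Dramatiq's standard results API. A minimal sketch reusing the illustrative actor from earlier; it assumes the dramatiq `Results` middleware is configured with `PostgresBackend` and that the actor stores results, neither of which is shown in this diff:

# Hypothetical: apply_blueprint is the illustrative actor from the earlier sketch.
message = apply_blueprint.send("some-uuid")

# Blocks until the actor finishes and its result has been written by the backend.
result = message.get_result(block=True, timeout=10_000)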