Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-09 02:30:38 +01:00
parent 2b1ee8cd5c
commit c4b988c632
6 changed files with 53 additions and 103 deletions

View File

@ -13,6 +13,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
def ready(self) -> None: def ready(self) -> None:
from authentik.tasks.broker import PostgresBroker from authentik.tasks.broker import PostgresBroker
from authentik.tasks.middleware import CurrentTask
dramatiq.set_encoder(JSONPickleEncoder()) dramatiq.set_encoder(JSONPickleEncoder())
broker = PostgresBroker() broker = PostgresBroker()
@ -21,5 +22,6 @@ class AuthentikTasksConfig(ManagedAppConfig):
broker.add_middleware(TimeLimit()) broker.add_middleware(TimeLimit())
broker.add_middleware(Callbacks()) broker.add_middleware(Callbacks())
broker.add_middleware(Retries(max_retries=3)) broker.add_middleware(Retries(max_retries=3))
broker.add_middleware(CurrentTask())
dramatiq.set_broker(broker) dramatiq.set_broker(broker)
return super().ready() return super().ready()

View File

@ -21,7 +21,7 @@ from psycopg import Notify
from psycopg.errors import AdminShutdown from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.results import PostgresBackend from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
@ -65,7 +65,7 @@ class PostgresBroker(Broker):
@property @property
def query_set(self) -> QuerySet: def query_set(self) -> QuerySet:
return MQueue.objects.using(self.db_alias) return Task.objects.using(self.db_alias)
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer: def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
self.declare_queue(queue_name) self.declare_queue(queue_name)
@ -125,7 +125,7 @@ class PostgresBroker(Broker):
defaults = { defaults = {
"tenant": get_current_tenant(), "tenant": get_current_tenant(),
"queue_name": message.queue_name, "queue_name": message.queue_name,
"state": MQueue.State.QUEUED, "state": Task.State.QUEUED,
"message": encoded, "message": encoded,
} }
create_defaults = { create_defaults = {
@ -169,7 +169,7 @@ class PostgresBroker(Broker):
if ( if (
self.query_set.filter( self.query_set.filter(
queue_name=queue_name, queue_name=queue_name,
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), state__in=(Task.State.QUEUED, Task.State.CONSUMED),
) )
== 0 == 0
): ):
@ -200,7 +200,7 @@ class _PostgresConsumer(Consumer):
@property @property
def query_set(self) -> QuerySet: def query_set(self) -> QuerySet:
return MQueue.objects.using(self.db_alias) return Task.objects.using(self.db_alias)
@property @property
def listen_connection(self) -> DatabaseWrapper: def listen_connection(self) -> DatabaseWrapper:
@ -220,9 +220,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter( self.query_set.filter(
message_id=message.message_id, message_id=message.message_id,
queue_name=message.queue_name, queue_name=message.queue_name,
state=MQueue.State.CONSUMED, state=Task.State.CONSUMED,
).update( ).update(
state=MQueue.State.DONE, state=Task.State.DONE,
message=message.encode(), message=message.encode(),
) )
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)
@ -233,9 +233,9 @@ class _PostgresConsumer(Consumer):
self.query_set.filter( self.query_set.filter(
message_id=message.message_id, message_id=message.message_id,
queue_name=message.queue_name, queue_name=message.queue_name,
state__ne=MQueue.State.REJECTED, state__ne=Task.State.REJECTED,
).update( ).update(
state=MQueue.State.REJECTED, state=Task.State.REJECTED,
message=message.encode(), message=message.encode(),
) )
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)
@ -245,14 +245,14 @@ class _PostgresConsumer(Consumer):
self.query_set.filter( self.query_set.filter(
message_id__in=[message.message_id for message in messages], message_id__in=[message.message_id for message in messages],
).update( ).update(
state=MQueue.State.QUEUED, state=Task.State.QUEUED,
) )
# We don't care about locks, requeue occurs on worker stop # We don't care about locks, requeue occurs on worker stop
def _fetch_pending_notifies(self) -> list[Notify]: def _fetch_pending_notifies(self) -> list[Notify]:
self.logger.debug(f"Polling for lost messages in {self.queue_name}") self.logger.debug(f"Polling for lost messages in {self.queue_name}")
notifies = self.query_set.filter( notifies = self.query_set.filter(
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name state__in=(Task.State.QUEUED, Task.State.CONSUMED), queue_name=self.queue_name
) )
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE) channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies] return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
@ -276,9 +276,9 @@ class _PostgresConsumer(Consumer):
result = ( result = (
self.query_set.filter( self.query_set.filter(
message_id=message.message_id, message_id=message.message_id,
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), state__in=(Task.State.QUEUED, Task.State.CONSUMED),
) )
.update(state=MQueue.State.CONSUMED, mtime=timezone.now()) .update(state=Task.State.CONSUMED, mtime=timezone.now())
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)]) .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
) )
return result == 1 return result == 1
@ -354,7 +354,7 @@ class _PostgresConsumer(Consumer):
return return
self.logger.debug("Running garbage collector") self.logger.debug("Running garbage collector")
count = self.query_set.filter( count = self.query_set.filter(
state__in=(MQueue.State.DONE, MQueue.State.REJECTED), state__in=(Task.State.DONE, Task.State.REJECTED),
mtime__lte=timezone.now() - timezone.timedelta(days=30), mtime__lte=timezone.now() - timezone.timedelta(days=30),
).delete() ).delete()
self.logger.info(f"Purged {count} messages in all queues") self.logger.info(f"Purged {count} messages in all queues")

View File

@ -0,0 +1,19 @@
import contextvars
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from authentik.tasks.models import Task
class CurrentTask(Middleware):
_TASK: contextvars.ContextVar[Task | None] = contextvars.ContextVar("_TASK", default=None)
@classmethod
def get_task(cls) -> Task | None:
return cls._TASK.get()
def before_process_message(self, _, message: Message):
self._TASK.set(Task.objects.get(message_id=message.message_id))
def after_process_message(self, *args, **kwargs):
self._TASK.get().save()
self._TASK.set(None)

View File

@ -1,79 +0,0 @@
# Generated by Django 5.0.12 on 2025-03-09 00:47
import django.db.models.deletion
import django.utils.timezone
import pgtrigger.compiler
import pgtrigger.migrations
import uuid
from django.db import migrations, models
class Migration(migrations.Migration):
initial = True
dependencies = [
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
]
operations = [
migrations.CreateModel(
name="Queue",
fields=[
(
"message_id",
models.UUIDField(
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
),
),
("queue_name", models.TextField(default="default")),
(
"state",
models.CharField(
choices=[
("queued", "Queued"),
("consumed", "Consumed"),
("rejected", "Rejected"),
("done", "Done"),
],
default="queued",
),
),
("mtime", models.DateTimeField(default=django.utils.timezone.now)),
("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField(blank=True, null=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
),
),
],
options={
"indexes": [
models.Index(fields=["state", "mtime"], name="authentik_t_state_b7ff76_idx")
],
},
),
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="queue",
trigger=pgtrigger.compiler.Trigger(
name="notify_enqueueing",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_b1977",
table="authentik_tasks_queue",
timing="DEFERRABLE INITIALLY DEFERRED",
when="AFTER",
),
),
),
]

View File

@ -5,6 +5,7 @@ import pgtrigger
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
from authentik.lib.models import SerializerModel
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
CHANNEL_PREFIX = "authentik.tasks" CHANNEL_PREFIX = "authentik.tasks"
@ -15,21 +16,23 @@ class ChannelIdentifier(StrEnum):
LOCK = auto() LOCK = auto()
class Queue(models.Model): class Task(SerializerModel):
class State(models.TextChoices): class State(models.TextChoices):
QUEUED = "queued" QUEUED = "queued"
CONSUMED = "consumed" CONSUMED = "consumed"
REJECTED = "rejected" REJECTED = "rejected"
DONE = "done" DONE = "done"
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False) message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
queue_name = models.TextField(default="default") tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False)
state = models.CharField(default=State.QUEUED, choices=State.choices) queue_name = models.TextField(default="default", editable=False)
mtime = models.DateTimeField(default=timezone.now) state = models.CharField(default=State.QUEUED, choices=State.choices, editable=False)
message = models.JSONField(blank=True, null=True) mtime = models.DateTimeField(default=timezone.now, editable=False)
result = models.JSONField(blank=True, null=True) message = models.JSONField(blank=True, null=True, editable=False)
result_ttl = models.DateTimeField(blank=True, null=True) result = models.JSONField(blank=True, null=True, editable=False)
result_ttl = models.DateTimeField(blank=True, null=True, editable=False)
description = models.TextField(blank=True)
messages = models.JSONField(blank=True, null=True, editable=False)
class Meta: class Meta:
indexes = (models.Index(fields=("state", "mtime")),) indexes = (models.Index(fields=("state", "mtime")),)
@ -55,3 +58,8 @@ class Queue(models.Model):
def __str__(self): def __str__(self):
return str(self.message_id) return str(self.message_id)
@property
def serializer(self):
# TODO: fixme
pass

View File

@ -4,7 +4,7 @@ from django.utils import timezone
from dramatiq.message import Message, get_encoder from dramatiq.message import Message, get_encoder
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
from authentik.tasks.models import Queue from authentik.tasks.models import Task
class PostgresBackend(ResultBackend): class PostgresBackend(ResultBackend):
@ -14,7 +14,7 @@ class PostgresBackend(ResultBackend):
@property @property
def query_set(self) -> QuerySet: def query_set(self) -> QuerySet:
return Queue.objects.using(self.db_alias) return Task.objects.using(self.db_alias)
def build_message_key(self, message: Message) -> str: def build_message_key(self, message: Message) -> str:
return str(message.message_id) return str(message.message_id)