From c4b988c6327b33b166dbab3cb5d6c57427e27c22 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Sun, 9 Mar 2025 02:30:38 +0100 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/tasks/apps.py | 2 + authentik/tasks/broker.py | 28 ++++---- authentik/tasks/middleware.py | 19 ++++++ authentik/tasks/migrations/0001_initial.py | 79 ---------------------- authentik/tasks/models.py | 24 ++++--- authentik/tasks/results.py | 4 +- 6 files changed, 53 insertions(+), 103 deletions(-) create mode 100644 authentik/tasks/middleware.py delete mode 100644 authentik/tasks/migrations/0001_initial.py diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py index 246dbae274..4061b0af6c 100644 --- a/authentik/tasks/apps.py +++ b/authentik/tasks/apps.py @@ -13,6 +13,7 @@ class AuthentikTasksConfig(ManagedAppConfig): def ready(self) -> None: from authentik.tasks.broker import PostgresBroker + from authentik.tasks.middleware import CurrentTask dramatiq.set_encoder(JSONPickleEncoder()) broker = PostgresBroker() @@ -21,5 +22,6 @@ class AuthentikTasksConfig(ManagedAppConfig): broker.add_middleware(TimeLimit()) broker.add_middleware(Callbacks()) broker.add_middleware(Retries(max_retries=3)) + broker.add_middleware(CurrentTask()) dramatiq.set_broker(broker) return super().ready() diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index b9fae8f935..6028820c6a 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -21,7 +21,7 @@ from psycopg import Notify from psycopg.errors import AdminShutdown from structlog.stdlib import get_logger -from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier +from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier from authentik.tasks.results import PostgresBackend from authentik.tenants.utils import get_current_tenant @@ -65,7 +65,7 @@ class PostgresBroker(Broker): @property def query_set(self) -> QuerySet: - return MQueue.objects.using(self.db_alias) + return Task.objects.using(self.db_alias) def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer: self.declare_queue(queue_name) @@ -125,7 +125,7 @@ class PostgresBroker(Broker): defaults = { "tenant": get_current_tenant(), "queue_name": message.queue_name, - "state": MQueue.State.QUEUED, + "state": Task.State.QUEUED, "message": encoded, } create_defaults = { @@ -169,7 +169,7 @@ class PostgresBroker(Broker): if ( self.query_set.filter( queue_name=queue_name, - state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), + state__in=(Task.State.QUEUED, Task.State.CONSUMED), ) == 0 ): @@ -200,7 +200,7 @@ class _PostgresConsumer(Consumer): @property def query_set(self) -> QuerySet: - return MQueue.objects.using(self.db_alias) + return Task.objects.using(self.db_alias) @property def listen_connection(self) -> DatabaseWrapper: @@ -220,9 +220,9 @@ class _PostgresConsumer(Consumer): self.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, - state=MQueue.State.CONSUMED, + state=Task.State.CONSUMED, ).update( - state=MQueue.State.DONE, + state=Task.State.DONE, message=message.encode(), ) self.in_processing.remove(message.message_id) @@ -233,9 +233,9 @@ class _PostgresConsumer(Consumer): self.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, - state__ne=MQueue.State.REJECTED, + state__ne=Task.State.REJECTED, ).update( - state=MQueue.State.REJECTED, + state=Task.State.REJECTED, message=message.encode(), ) self.in_processing.remove(message.message_id) @@ -245,14 +245,14 @@ class _PostgresConsumer(Consumer): self.query_set.filter( message_id__in=[message.message_id for message in messages], ).update( - state=MQueue.State.QUEUED, + state=Task.State.QUEUED, ) # We don't care about locks, requeue occurs on worker stop def _fetch_pending_notifies(self) -> list[Notify]: self.logger.debug(f"Polling for lost messages in {self.queue_name}") notifies = self.query_set.filter( - state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name + state__in=(Task.State.QUEUED, Task.State.CONSUMED), queue_name=self.queue_name ) channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE) return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies] @@ -276,9 +276,9 @@ class _PostgresConsumer(Consumer): result = ( self.query_set.filter( message_id=message.message_id, - state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), + state__in=(Task.State.QUEUED, Task.State.CONSUMED), ) - .update(state=MQueue.State.CONSUMED, mtime=timezone.now()) + .update(state=Task.State.CONSUMED, mtime=timezone.now()) .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)]) ) return result == 1 @@ -354,7 +354,7 @@ class _PostgresConsumer(Consumer): return self.logger.debug("Running garbage collector") count = self.query_set.filter( - state__in=(MQueue.State.DONE, MQueue.State.REJECTED), + state__in=(Task.State.DONE, Task.State.REJECTED), mtime__lte=timezone.now() - timezone.timedelta(days=30), ).delete() self.logger.info(f"Purged {count} messages in all queues") diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py new file mode 100644 index 0000000000..1c1c9de30d --- /dev/null +++ b/authentik/tasks/middleware.py @@ -0,0 +1,19 @@ +import contextvars +from dramatiq.message import Message +from dramatiq.middleware import Middleware +from authentik.tasks.models import Task + + +class CurrentTask(Middleware): + _TASK: contextvars.ContextVar[Task | None] = contextvars.ContextVar("_TASK", default=None) + + @classmethod + def get_task(cls) -> Task | None: + return cls._TASK.get() + + def before_process_message(self, _, message: Message): + self._TASK.set(Task.objects.get(message_id=message.message_id)) + + def after_process_message(self, *args, **kwargs): + self._TASK.get().save() + self._TASK.set(None) diff --git a/authentik/tasks/migrations/0001_initial.py b/authentik/tasks/migrations/0001_initial.py deleted file mode 100644 index 4e5bab19f0..0000000000 --- a/authentik/tasks/migrations/0001_initial.py +++ /dev/null @@ -1,79 +0,0 @@ -# Generated by Django 5.0.12 on 2025-03-09 00:47 - -import django.db.models.deletion -import django.utils.timezone -import pgtrigger.compiler -import pgtrigger.migrations -import uuid -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ("authentik_tenants", "0004_tenant_impersonation_require_reason"), - ] - - operations = [ - migrations.CreateModel( - name="Queue", - fields=[ - ( - "message_id", - models.UUIDField( - default=uuid.uuid4, editable=False, primary_key=True, serialize=False - ), - ), - ("queue_name", models.TextField(default="default")), - ( - "state", - models.CharField( - choices=[ - ("queued", "Queued"), - ("consumed", "Consumed"), - ("rejected", "Rejected"), - ("done", "Done"), - ], - default="queued", - ), - ), - ("mtime", models.DateTimeField(default=django.utils.timezone.now)), - ("message", models.JSONField(blank=True, null=True)), - ("result", models.JSONField(blank=True, null=True)), - ("result_ttl", models.DateTimeField(blank=True, null=True)), - ( - "tenant", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant" - ), - ), - ], - options={ - "indexes": [ - models.Index(fields=["state", "mtime"], name="authentik_t_state_b7ff76_idx") - ], - }, - ), - migrations.RunSQL( - "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop - ), - pgtrigger.migrations.AddTrigger( - model_name="queue", - trigger=pgtrigger.compiler.Trigger( - name="notify_enqueueing", - sql=pgtrigger.compiler.UpsertTriggerSql( - condition="WHEN (NEW.\"state\" = 'queued')", - constraint="CONSTRAINT", - func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ", - hash="d604c5647f3821f100e8aa7a52be181bde9ebdce", - operation="INSERT OR UPDATE", - pgid="pgtrigger_notify_enqueueing_b1977", - table="authentik_tasks_queue", - timing="DEFERRABLE INITIALLY DEFERRED", - when="AFTER", - ), - ), - ), - ] diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py index 1b0b93bf87..b081c44ec7 100644 --- a/authentik/tasks/models.py +++ b/authentik/tasks/models.py @@ -5,6 +5,7 @@ import pgtrigger from django.db import models from django.utils import timezone +from authentik.lib.models import SerializerModel from authentik.tenants.models import Tenant CHANNEL_PREFIX = "authentik.tasks" @@ -15,21 +16,23 @@ class ChannelIdentifier(StrEnum): LOCK = auto() -class Queue(models.Model): +class Task(SerializerModel): class State(models.TextChoices): QUEUED = "queued" CONSUMED = "consumed" REJECTED = "rejected" DONE = "done" - tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE) message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False) - queue_name = models.TextField(default="default") - state = models.CharField(default=State.QUEUED, choices=State.choices) - mtime = models.DateTimeField(default=timezone.now) - message = models.JSONField(blank=True, null=True) - result = models.JSONField(blank=True, null=True) - result_ttl = models.DateTimeField(blank=True, null=True) + tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False) + queue_name = models.TextField(default="default", editable=False) + state = models.CharField(default=State.QUEUED, choices=State.choices, editable=False) + mtime = models.DateTimeField(default=timezone.now, editable=False) + message = models.JSONField(blank=True, null=True, editable=False) + result = models.JSONField(blank=True, null=True, editable=False) + result_ttl = models.DateTimeField(blank=True, null=True, editable=False) + description = models.TextField(blank=True) + messages = models.JSONField(blank=True, null=True, editable=False) class Meta: indexes = (models.Index(fields=("state", "mtime")),) @@ -55,3 +58,8 @@ class Queue(models.Model): def __str__(self): return str(self.message_id) + + @property + def serializer(self): + # TODO: fixme + pass diff --git a/authentik/tasks/results.py b/authentik/tasks/results.py index 2c09374585..774f5e0969 100644 --- a/authentik/tasks/results.py +++ b/authentik/tasks/results.py @@ -4,7 +4,7 @@ from django.utils import timezone from dramatiq.message import Message, get_encoder from dramatiq.results.backend import Missing, MResult, Result, ResultBackend -from authentik.tasks.models import Queue +from authentik.tasks.models import Task class PostgresBackend(ResultBackend): @@ -14,7 +14,7 @@ class PostgresBackend(ResultBackend): @property def query_set(self) -> QuerySet: - return Queue.objects.using(self.db_alias) + return Task.objects.using(self.db_alias) def build_message_key(self, message: Message) -> str: return str(message.message_id)