From 2b1ee8cd5c16fcd6477ffb65dd4fe2e43b285e62 Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Sun, 9 Mar 2025 01:49:03 +0100 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/root/settings.py | 1 + authentik/tasks/broker.py | 44 +++++++++++++--------- authentik/tasks/migrations/0001_initial.py | 32 +++++++++++++++- authentik/tasks/models.py | 28 ++++++++++++++ authentik/tasks/results.py | 21 ++++++++--- poetry.lock | 17 ++++++++- pyproject.toml | 1 + 7 files changed, 117 insertions(+), 27 deletions(-) diff --git a/authentik/root/settings.py b/authentik/root/settings.py index f8976ecd74..b20796ef8e 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -73,6 +73,7 @@ TENANT_APPS = [ "django.contrib.auth", "django.contrib.contenttypes", "django.contrib.sessions", + "pgtrigger", "authentik.admin", "authentik.api", "authentik.crypto", diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index d2c2c2b61e..b9fae8f935 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,4 +1,3 @@ -from enum import Enum, StrEnum, auto import functools import logging import time @@ -22,20 +21,15 @@ from psycopg import Notify from psycopg.errors import AdminShutdown from structlog.stdlib import get_logger -from authentik.tasks.models import Queue as MQueue +from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier from authentik.tasks.results import PostgresBackend from authentik.tenants.utils import get_current_tenant LOGGER = get_logger() -class ChannelIdentifier(StrEnum): - ENQUEUE = auto() - LOCK = auto() - - def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: - return f"authentik.tasks.{queue_name}.{identifier.value}" + return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}" def raise_connection_error(func): @@ -62,7 +56,11 @@ class PostgresBroker(Broker): self.add_middleware(Results(backend=self.backend)) @property - def consumer_class(self): + def connection(self) -> DatabaseWrapper: + return connections[self.db_alias] + + @property + def consumer_class(self) -> "type[_PostgresConsumer]": return _PostgresConsumer @property @@ -120,14 +118,24 @@ class PostgresBroker(Broker): self.declare_queue(canonical_queue_name) self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") self.emit_before("enqueue", message, delay) - # TODO: notify - # TODO: update_or_create - self.query_set.create( - message_id=message.message_id, - tenant=get_current_tenant(), - queue_name=message.queue_name, - state=MQueue.State.QUEUED, - message=message.encode(), + encoded = message.encode() + query = { + "message_id": message.message_id, + } + defaults = { + "tenant": get_current_tenant(), + "queue_name": message.queue_name, + "state": MQueue.State.QUEUED, + "message": encoded, + } + create_defaults = { + **query, + **defaults, + } + self.query_set.update_or_create( + **query, + defaults=defaults, + create_defaults=create_defaults, ) self.emit_after("enqueue", message, delay) return message @@ -227,7 +235,7 @@ class _PostgresConsumer(Consumer): queue_name=message.queue_name, state__ne=MQueue.State.REJECTED, ).update( - state=MQueue.State.REJECT, + state=MQueue.State.REJECTED, message=message.encode(), ) self.in_processing.remove(message.message_id) diff --git a/authentik/tasks/migrations/0001_initial.py b/authentik/tasks/migrations/0001_initial.py index c182cf0c2d..4e5bab19f0 100644 --- a/authentik/tasks/migrations/0001_initial.py +++ b/authentik/tasks/migrations/0001_initial.py @@ -1,6 +1,9 @@ -# Generated by Django 5.0.12 on 2025-03-08 22:39 +# Generated by Django 5.0.12 on 2025-03-09 00:47 +import django.db.models.deletion import django.utils.timezone +import pgtrigger.compiler +import pgtrigger.migrations import uuid from django.db import migrations, models @@ -9,7 +12,9 @@ class Migration(migrations.Migration): initial = True - dependencies = [] + dependencies = [ + ("authentik_tenants", "0004_tenant_impersonation_require_reason"), + ] operations = [ migrations.CreateModel( @@ -38,6 +43,12 @@ class Migration(migrations.Migration): ("message", models.JSONField(blank=True, null=True)), ("result", models.JSONField(blank=True, null=True)), ("result_ttl", models.DateTimeField(blank=True, null=True)), + ( + "tenant", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant" + ), + ), ], options={ "indexes": [ @@ -48,4 +59,21 @@ class Migration(migrations.Migration): migrations.RunSQL( "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop ), + pgtrigger.migrations.AddTrigger( + model_name="queue", + trigger=pgtrigger.compiler.Trigger( + name="notify_enqueueing", + sql=pgtrigger.compiler.UpsertTriggerSql( + condition="WHEN (NEW.\"state\" = 'queued')", + constraint="CONSTRAINT", + func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ", + hash="d604c5647f3821f100e8aa7a52be181bde9ebdce", + operation="INSERT OR UPDATE", + pgid="pgtrigger_notify_enqueueing_b1977", + table="authentik_tasks_queue", + timing="DEFERRABLE INITIALLY DEFERRED", + when="AFTER", + ), + ), + ), ] diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py index 37fc5201b4..1b0b93bf87 100644 --- a/authentik/tasks/models.py +++ b/authentik/tasks/models.py @@ -1,10 +1,19 @@ +from enum import StrEnum, auto from uuid import uuid4 +import pgtrigger from django.db import models from django.utils import timezone from authentik.tenants.models import Tenant +CHANNEL_PREFIX = "authentik.tasks" + + +class ChannelIdentifier(StrEnum): + ENQUEUE = auto() + LOCK = auto() + class Queue(models.Model): class State(models.TextChoices): @@ -24,6 +33,25 @@ class Queue(models.Model): class Meta: indexes = (models.Index(fields=("state", "mtime")),) + triggers = ( + pgtrigger.Trigger( + name="notify_enqueueing", + operation=pgtrigger.Insert | pgtrigger.Update, + when=pgtrigger.After, + condition=pgtrigger.Q(new__state="queued"), + timing=pgtrigger.Deferred, + func=f""" + PERFORM pg_notify( + '{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}', + CASE WHEN octet_length(NEW.message::text) >= 8000 + THEN jsonb_build_object('message_id', NEW.message_id)::text + ELSE message::text + END + ); + RETURN NEW; + """, # noqa: E501 + ), + ) def __str__(self): return str(self.message_id) diff --git a/authentik/tasks/results.py b/authentik/tasks/results.py index 8e014fd396..2c09374585 100644 --- a/authentik/tasks/results.py +++ b/authentik/tasks/results.py @@ -29,10 +29,19 @@ class PostgresBackend(ResultBackend): return self.encoder.decode(data) def _store(self, message_key: str, result: Result, ttl: int) -> None: - # TODO: update_or_create encoder = get_encoder() - self.query_set.filter(message_id=message_key).update( - mtime=timezone.now(), - result=encoder.encode(result), - result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl), - ) + query = { + "message_id": message_key, + } + defaults = { + "mtime": timezone.now(), + "result": encoder.encode(result), + "result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl), + } + create_defaults = { + **query, + **defaults, + "queue_name": "__RQ__", + "state": Queue.State.DONE, + } + self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults) diff --git a/poetry.lock b/poetry.lock index 28b517af5d..e484a8870d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1534,6 +1534,21 @@ files = [ django = ">=4" django-pgactivity = ">=1.2,<2" +[[package]] +name = "django-pgtrigger" +version = "4.13.3" +description = "Postgres trigger support integrated with Django models." +category = "main" +optional = false +python-versions = "<4,>=3.9.0" +files = [ + {file = "django_pgtrigger-4.13.3-py3-none-any.whl", hash = "sha256:d6e4d17021bbd5e425a308f07414b237b9b34423275d86ad756b90c307df3ca4"}, + {file = "django_pgtrigger-4.13.3.tar.gz", hash = "sha256:c525f9e81f120d166c4bd5fe8c3770640356f0644edf0fc2b7f6426008e52f77"}, +] + +[package.dependencies] +django = ">=4" + [[package]] name = "django-prometheus" version = "2.3.1" @@ -6341,4 +6356,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "~3.12" -content-hash = "2c4a0e7ad08377212eb8ef2c3a372a557a407c43f2e2b3429403598477106c55" +content-hash = "1f8cfdb90051911b2d677a9fd9ad4de8afea6fd8644de4bc25d418652506b7e2" diff --git a/pyproject.toml b/pyproject.toml index ca736cc8ed..4a2f5b2c95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ django-filter = "*" django-guardian = "*" django-model-utils = "*" django-pglock = "*" +django-pgtrigger = "*" django-prometheus = "*" django-redis = "*" django-storages = { extras = ["s3"], version = "*" }