Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Marc 'risson' Schmitt
2025-03-09 01:49:03 +01:00
parent e8cfc2b91e
commit 2b1ee8cd5c
7 changed files with 117 additions and 27 deletions


@@ -73,6 +73,7 @@ TENANT_APPS = [
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"pgtrigger",
"authentik.admin",
"authentik.api",
"authentik.crypto",


@@ -1,4 +1,3 @@
from enum import Enum, StrEnum, auto
import functools
import logging
import time
@@ -22,20 +21,15 @@ from psycopg import Notify
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"authentik.tasks.{queue_name}.{identifier.value}"
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
def raise_connection_error(func):
@@ -62,7 +56,11 @@ class PostgresBroker(Broker):
self.add_middleware(Results(backend=self.backend))
@property
def consumer_class(self):
def connection(self) -> DatabaseWrapper:
return connections[self.db_alias]
@property
def consumer_class(self) -> "type[_PostgresConsumer]":
return _PostgresConsumer
@property
@@ -120,14 +118,24 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
self.emit_before("enqueue", message, delay)
# TODO: notify
# TODO: update_or_create
self.query_set.create(
message_id=message.message_id,
tenant=get_current_tenant(),
queue_name=message.queue_name,
state=MQueue.State.QUEUED,
message=message.encode(),
encoded = message.encode()
query = {
"message_id": message.message_id,
}
defaults = {
"tenant": get_current_tenant(),
"queue_name": message.queue_name,
"state": MQueue.State.QUEUED,
"message": encoded,
}
create_defaults = {
**query,
**defaults,
}
self.query_set.update_or_create(
**query,
defaults=defaults,
create_defaults=create_defaults,
)
self.emit_after("enqueue", message, delay)
return message
@@ -227,7 +235,7 @@ class _PostgresConsumer(Consumer):
queue_name=message.queue_name,
state__ne=MQueue.State.REJECTED,
).update(
state=MQueue.State.REJECT,
state=MQueue.State.REJECTED,
message=message.encode(),
)
self.in_processing.remove(message.message_id)
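
The enqueue path above moves from create() to update_or_create(), presumably so that re-enqueueing an existing message_id updates the row instead of raising an integrity error. For reference, create_defaults is a Django 5.0+ argument: the lookup kwargs are used on both paths, defaults only when the row already exists, create_defaults only when it is inserted. A minimal illustration of those semantics with a hypothetical model that is not part of this commit, assuming a configured Django 5.x project:

    # Hypothetical model and call, for illustration only.
    from django.db import models

    class Counter(models.Model):
        key = models.CharField(max_length=64, unique=True)
        hits = models.IntegerField(default=0)
        note = models.TextField(blank=True)

    obj, created = Counter.objects.update_or_create(
        key="signup",                     # lookup kwargs, used on both paths
        defaults={"note": "seen again"},  # applied only when the row already exists
        create_defaults={"hits": 1},      # applied only on insert (Django >= 5.0)
    )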


@@ -1,6 +1,9 @@
# Generated by Django 5.0.12 on 2025-03-08 22:39
# Generated by Django 5.0.12 on 2025-03-09 00:47
import django.db.models.deletion
import django.utils.timezone
import pgtrigger.compiler
import pgtrigger.migrations
import uuid
from django.db import migrations, models
@@ -9,7 +12,9 @@ class Migration(migrations.Migration):
initial = True
dependencies = []
dependencies = [
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
]
operations = [
migrations.CreateModel(
@@ -38,6 +43,12 @@ class Migration(migrations.Migration):
("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField(blank=True, null=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
),
),
],
options={
"indexes": [
@@ -48,4 +59,21 @@ class Migration(migrations.Migration):
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="queue",
trigger=pgtrigger.compiler.Trigger(
name="notify_enqueueing",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_b1977",
table="authentik_tasks_queue",
timing="DEFERRABLE INITIALLY DEFERRED",
when="AFTER",
),
),
),
]
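
The generated trigger mirrors the model-level definition below: an AFTER INSERT OR UPDATE constraint trigger, deferred to commit, that calls pg_notify whenever a row lands in the 'queued' state. Because NOTIFY payloads are limited to a bit under 8000 bytes, the CASE falls back to sending only the message_id for large messages, and the listener then has to read the full row back itself. A minimal listener sketch, assuming psycopg 3, a reachable DSN, the default queue, and the dotted channel form that channel_name() produces:

    # Sketch only: DSN, queue name, and the fetch step are illustrative.
    import json
    import psycopg

    with psycopg.connect("dbname=authentik", autocommit=True) as conn:
        conn.execute('LISTEN "authentik.tasks.default.enqueue";')
        for notification in conn.notifies():
            payload = json.loads(notification.payload)
            if set(payload) == {"message_id"}:
                # Message body exceeded the NOTIFY payload limit; only the id was
                # sent, so fetch the full message from authentik_tasks_queue.
                print("fetch row for", payload["message_id"])
            else:
                # Small messages arrive inline as the encoded dramatiq message.
                print("got message", payload.get("message_id"))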


@@ -1,10 +1,19 @@
from enum import StrEnum, auto
from uuid import uuid4
import pgtrigger
from django.db import models
from django.utils import timezone
from authentik.tenants.models import Tenant
CHANNEL_PREFIX = "authentik.tasks"
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
class Queue(models.Model):
class State(models.TextChoices):
@@ -24,6 +33,25 @@ class Queue(models.Model):
class Meta:
indexes = (models.Index(fields=("state", "mtime")),)
triggers = (
pgtrigger.Trigger(
name="notify_enqueueing",
operation=pgtrigger.Insert | pgtrigger.Update,
when=pgtrigger.After,
condition=pgtrigger.Q(new__state="queued"),
timing=pgtrigger.Deferred,
func=f"""
PERFORM pg_notify(
'{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
CASE WHEN octet_length(NEW.message::text) >= 8000
THEN jsonb_build_object('message_id', NEW.message_id)::text
ELSE NEW.message::text
END
);
RETURN NEW;
""", # noqa: E501
),
)
def __str__(self):
return str(self.message_id)
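
The channel string the trigger emits has to agree with the broker-side channel_name() helper shown in the broker diff above, which builds the same dotted name from CHANNEL_PREFIX and ChannelIdentifier. A minimal sketch of that mapping for a queue named "default" (the queue name is illustrative):

    # Mirrors CHANNEL_PREFIX / ChannelIdentifier above.
    from enum import StrEnum, auto

    CHANNEL_PREFIX = "authentik.tasks"

    class ChannelIdentifier(StrEnum):
        ENQUEUE = auto()  # auto() on a StrEnum yields the lower-cased member name
        LOCK = auto()

    def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
        return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"

    assert channel_name("default", ChannelIdentifier.ENQUEUE) == "authentik.tasks.default.enqueue"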


@@ -29,10 +29,19 @@ class PostgresBackend(ResultBackend):
return self.encoder.decode(data)
def _store(self, message_key: str, result: Result, ttl: int) -> None:
# TODO: update_or_create
encoder = get_encoder()
self.query_set.filter(message_id=message_key).update(
mtime=timezone.now(),
result=encoder.encode(result),
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl),
)
query = {
"message_id": message_key,
}
defaults = {
"mtime": timezone.now(),
"result": encoder.encode(result),
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
}
create_defaults = {
**query,
**defaults,
"queue_name": "__RQ__",
"state": Queue.State.DONE,
}
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)
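
With the result row now upserted rather than updated in place, a result can be stored even when no queue row exists yet, apparently with "__RQ__" and the DONE state as sentinels for that case. Results are then read back through dramatiq's Results middleware, which the broker registers above. A small usage sketch, assuming the PostgresBroker is configured as the global broker; the actor name and arguments are illustrative:

    # Usage sketch only; requires the Results middleware configured as in the broker diff.
    import dramatiq

    @dramatiq.actor(store_results=True)
    def add(x, y):
        return x + y

    message = add.send(2, 3)
    # get_result() polls the result backend (here the authentik_tasks_queue row
    # written by _store) until the result arrives or the timeout (ms) expires.
    print(message.get_result(block=True, timeout=10_000))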

poetry.lock (generated)

@@ -1534,6 +1534,21 @@ files = [
django = ">=4"
django-pgactivity = ">=1.2,<2"
[[package]]
name = "django-pgtrigger"
version = "4.13.3"
description = "Postgres trigger support integrated with Django models."
category = "main"
optional = false
python-versions = "<4,>=3.9.0"
files = [
{file = "django_pgtrigger-4.13.3-py3-none-any.whl", hash = "sha256:d6e4d17021bbd5e425a308f07414b237b9b34423275d86ad756b90c307df3ca4"},
{file = "django_pgtrigger-4.13.3.tar.gz", hash = "sha256:c525f9e81f120d166c4bd5fe8c3770640356f0644edf0fc2b7f6426008e52f77"},
]
[package.dependencies]
django = ">=4"
[[package]]
name = "django-prometheus"
version = "2.3.1"
@@ -6341,4 +6356,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "~3.12"
content-hash = "2c4a0e7ad08377212eb8ef2c3a372a557a407c43f2e2b3429403598477106c55"
content-hash = "1f8cfdb90051911b2d677a9fd9ad4de8afea6fd8644de4bc25d418652506b7e2"


@@ -124,6 +124,7 @@ django-filter = "*"
django-guardian = "*"
django-model-utils = "*"
django-pglock = "*"
django-pgtrigger = "*"
django-prometheus = "*"
django-redis = "*"
django-storages = { extras = ["s3"], version = "*" }