Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-09 01:49:03 +01:00
parent e8cfc2b91e
commit 2b1ee8cd5c
7 changed files with 117 additions and 27 deletions

View File

@ -73,6 +73,7 @@ TENANT_APPS = [
"django.contrib.auth", "django.contrib.auth",
"django.contrib.contenttypes", "django.contrib.contenttypes",
"django.contrib.sessions", "django.contrib.sessions",
"pgtrigger",
"authentik.admin", "authentik.admin",
"authentik.api", "authentik.api",
"authentik.crypto", "authentik.crypto",

View File

@ -1,4 +1,3 @@
from enum import Enum, StrEnum, auto
import functools import functools
import logging import logging
import time import time
@ -22,20 +21,15 @@ from psycopg import Notify
from psycopg.errors import AdminShutdown from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.results import PostgresBackend from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger() LOGGER = get_logger()
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"authentik.tasks.{queue_name}.{identifier.value}" return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
def raise_connection_error(func): def raise_connection_error(func):
@ -62,7 +56,11 @@ class PostgresBroker(Broker):
self.add_middleware(Results(backend=self.backend)) self.add_middleware(Results(backend=self.backend))
@property @property
def consumer_class(self): def connection(self) -> DatabaseWrapper:
return connections[self.db_alias]
@property
def consumer_class(self) -> "type[_PostgresConsumer]":
return _PostgresConsumer return _PostgresConsumer
@property @property
@ -120,14 +118,24 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name) self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
self.emit_before("enqueue", message, delay) self.emit_before("enqueue", message, delay)
# TODO: notify encoded = message.encode()
# TODO: update_or_create query = {
self.query_set.create( "message_id": message.message_id,
message_id=message.message_id, }
tenant=get_current_tenant(), defaults = {
queue_name=message.queue_name, "tenant": get_current_tenant(),
state=MQueue.State.QUEUED, "queue_name": message.queue_name,
message=message.encode(), "state": MQueue.State.QUEUED,
"message": encoded,
}
create_defaults = {
**query,
**defaults,
}
self.query_set.update_or_create(
**query,
defaults=defaults,
create_defaults=create_defaults,
) )
self.emit_after("enqueue", message, delay) self.emit_after("enqueue", message, delay)
return message return message
@ -227,7 +235,7 @@ class _PostgresConsumer(Consumer):
queue_name=message.queue_name, queue_name=message.queue_name,
state__ne=MQueue.State.REJECTED, state__ne=MQueue.State.REJECTED,
).update( ).update(
state=MQueue.State.REJECT, state=MQueue.State.REJECTED,
message=message.encode(), message=message.encode(),
) )
self.in_processing.remove(message.message_id) self.in_processing.remove(message.message_id)

View File

@ -1,6 +1,9 @@
# Generated by Django 5.0.12 on 2025-03-08 22:39 # Generated by Django 5.0.12 on 2025-03-09 00:47
import django.db.models.deletion
import django.utils.timezone import django.utils.timezone
import pgtrigger.compiler
import pgtrigger.migrations
import uuid import uuid
from django.db import migrations, models from django.db import migrations, models
@ -9,7 +12,9 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [] dependencies = [
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
@ -38,6 +43,12 @@ class Migration(migrations.Migration):
("message", models.JSONField(blank=True, null=True)), ("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)), ("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField(blank=True, null=True)), ("result_ttl", models.DateTimeField(blank=True, null=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
),
),
], ],
options={ options={
"indexes": [ "indexes": [
@ -48,4 +59,21 @@ class Migration(migrations.Migration):
migrations.RunSQL( migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
), ),
pgtrigger.migrations.AddTrigger(
model_name="queue",
trigger=pgtrigger.compiler.Trigger(
name="notify_enqueueing",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_b1977",
table="authentik_tasks_queue",
timing="DEFERRABLE INITIALLY DEFERRED",
when="AFTER",
),
),
),
] ]

View File

@ -1,10 +1,19 @@
from enum import StrEnum, auto
from uuid import uuid4 from uuid import uuid4
import pgtrigger
from django.db import models from django.db import models
from django.utils import timezone from django.utils import timezone
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
CHANNEL_PREFIX = "authentik.tasks"
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
class Queue(models.Model): class Queue(models.Model):
class State(models.TextChoices): class State(models.TextChoices):
@ -24,6 +33,25 @@ class Queue(models.Model):
class Meta: class Meta:
indexes = (models.Index(fields=("state", "mtime")),) indexes = (models.Index(fields=("state", "mtime")),)
triggers = (
pgtrigger.Trigger(
name="notify_enqueueing",
operation=pgtrigger.Insert | pgtrigger.Update,
when=pgtrigger.After,
condition=pgtrigger.Q(new__state="queued"),
timing=pgtrigger.Deferred,
func=f"""
PERFORM pg_notify(
'{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
CASE WHEN octet_length(NEW.message::text) >= 8000
THEN jsonb_build_object('message_id', NEW.message_id)::text
ELSE message::text
END
);
RETURN NEW;
""", # noqa: E501
),
)
def __str__(self): def __str__(self):
return str(self.message_id) return str(self.message_id)

View File

@ -29,10 +29,19 @@ class PostgresBackend(ResultBackend):
return self.encoder.decode(data) return self.encoder.decode(data)
def _store(self, message_key: str, result: Result, ttl: int) -> None: def _store(self, message_key: str, result: Result, ttl: int) -> None:
# TODO: update_or_create
encoder = get_encoder() encoder = get_encoder()
self.query_set.filter(message_id=message_key).update( query = {
mtime=timezone.now(), "message_id": message_key,
result=encoder.encode(result), }
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl), defaults = {
) "mtime": timezone.now(),
"result": encoder.encode(result),
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
}
create_defaults = {
**query,
**defaults,
"queue_name": "__RQ__",
"state": Queue.State.DONE,
}
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)

17
poetry.lock generated
View File

@ -1534,6 +1534,21 @@ files = [
django = ">=4" django = ">=4"
django-pgactivity = ">=1.2,<2" django-pgactivity = ">=1.2,<2"
[[package]]
name = "django-pgtrigger"
version = "4.13.3"
description = "Postgres trigger support integrated with Django models."
category = "main"
optional = false
python-versions = "<4,>=3.9.0"
files = [
{file = "django_pgtrigger-4.13.3-py3-none-any.whl", hash = "sha256:d6e4d17021bbd5e425a308f07414b237b9b34423275d86ad756b90c307df3ca4"},
{file = "django_pgtrigger-4.13.3.tar.gz", hash = "sha256:c525f9e81f120d166c4bd5fe8c3770640356f0644edf0fc2b7f6426008e52f77"},
]
[package.dependencies]
django = ">=4"
[[package]] [[package]]
name = "django-prometheus" name = "django-prometheus"
version = "2.3.1" version = "2.3.1"
@ -6341,4 +6356,4 @@ files = [
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "~3.12" python-versions = "~3.12"
content-hash = "2c4a0e7ad08377212eb8ef2c3a372a557a407c43f2e2b3429403598477106c55" content-hash = "1f8cfdb90051911b2d677a9fd9ad4de8afea6fd8644de4bc25d418652506b7e2"

View File

@ -124,6 +124,7 @@ django-filter = "*"
django-guardian = "*" django-guardian = "*"
django-model-utils = "*" django-model-utils = "*"
django-pglock = "*" django-pglock = "*"
django-pgtrigger = "*"
django-prometheus = "*" django-prometheus = "*"
django-redis = "*" django-redis = "*"
django-storages = { extras = ["s3"], version = "*" } django-storages = { extras = ["s3"], version = "*" }