Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-09 01:49:03 +01:00
parent e8cfc2b91e
commit 2b1ee8cd5c
7 changed files with 117 additions and 27 deletions

View File

@ -1,4 +1,3 @@
from enum import Enum, StrEnum, auto
import functools
import logging
import time
@ -22,20 +21,15 @@ from psycopg import Notify
from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from authentik.tasks.models import Queue as MQueue
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
from authentik.tasks.results import PostgresBackend
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
return f"authentik.tasks.{queue_name}.{identifier.value}"
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
def raise_connection_error(func):
@ -62,7 +56,11 @@ class PostgresBroker(Broker):
self.add_middleware(Results(backend=self.backend))
@property
def consumer_class(self):
def connection(self) -> DatabaseWrapper:
return connections[self.db_alias]
@property
def consumer_class(self) -> "type[_PostgresConsumer]":
return _PostgresConsumer
@property
@ -120,14 +118,24 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
self.emit_before("enqueue", message, delay)
# TODO: notify
# TODO: update_or_create
self.query_set.create(
message_id=message.message_id,
tenant=get_current_tenant(),
queue_name=message.queue_name,
state=MQueue.State.QUEUED,
message=message.encode(),
encoded = message.encode()
query = {
"message_id": message.message_id,
}
defaults = {
"tenant": get_current_tenant(),
"queue_name": message.queue_name,
"state": MQueue.State.QUEUED,
"message": encoded,
}
create_defaults = {
**query,
**defaults,
}
self.query_set.update_or_create(
**query,
defaults=defaults,
create_defaults=create_defaults,
)
self.emit_after("enqueue", message, delay)
return message
@ -227,7 +235,7 @@ class _PostgresConsumer(Consumer):
queue_name=message.queue_name,
state__ne=MQueue.State.REJECTED,
).update(
state=MQueue.State.REJECT,
state=MQueue.State.REJECTED,
message=message.encode(),
)
self.in_processing.remove(message.message_id)

View File

@ -1,6 +1,9 @@
# Generated by Django 5.0.12 on 2025-03-08 22:39
# Generated by Django 5.0.12 on 2025-03-09 00:47
import django.db.models.deletion
import django.utils.timezone
import pgtrigger.compiler
import pgtrigger.migrations
import uuid
from django.db import migrations, models
@ -9,7 +12,9 @@ class Migration(migrations.Migration):
initial = True
dependencies = []
dependencies = [
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
]
operations = [
migrations.CreateModel(
@ -38,6 +43,12 @@ class Migration(migrations.Migration):
("message", models.JSONField(blank=True, null=True)),
("result", models.JSONField(blank=True, null=True)),
("result_ttl", models.DateTimeField(blank=True, null=True)),
(
"tenant",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
),
),
],
options={
"indexes": [
@ -48,4 +59,21 @@ class Migration(migrations.Migration):
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="queue",
trigger=pgtrigger.compiler.Trigger(
name="notify_enqueueing",
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_b1977",
table="authentik_tasks_queue",
timing="DEFERRABLE INITIALLY DEFERRED",
when="AFTER",
),
),
),
]

View File

@ -1,10 +1,19 @@
from enum import StrEnum, auto
from uuid import uuid4
import pgtrigger
from django.db import models
from django.utils import timezone
from authentik.tenants.models import Tenant
CHANNEL_PREFIX = "authentik.tasks"
class ChannelIdentifier(StrEnum):
ENQUEUE = auto()
LOCK = auto()
class Queue(models.Model):
class State(models.TextChoices):
@ -24,6 +33,25 @@ class Queue(models.Model):
class Meta:
indexes = (models.Index(fields=("state", "mtime")),)
triggers = (
pgtrigger.Trigger(
name="notify_enqueueing",
operation=pgtrigger.Insert | pgtrigger.Update,
when=pgtrigger.After,
condition=pgtrigger.Q(new__state="queued"),
timing=pgtrigger.Deferred,
func=f"""
PERFORM pg_notify(
'{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
CASE WHEN octet_length(NEW.message::text) >= 8000
THEN jsonb_build_object('message_id', NEW.message_id)::text
ELSE message::text
END
);
RETURN NEW;
""", # noqa: E501
),
)
def __str__(self):
return str(self.message_id)

View File

@ -29,10 +29,19 @@ class PostgresBackend(ResultBackend):
return self.encoder.decode(data)
def _store(self, message_key: str, result: Result, ttl: int) -> None:
# TODO: update_or_create
encoder = get_encoder()
self.query_set.filter(message_id=message_key).update(
mtime=timezone.now(),
result=encoder.encode(result),
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl),
)
query = {
"message_id": message_key,
}
defaults = {
"mtime": timezone.now(),
"result": encoder.encode(result),
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
}
create_defaults = {
**query,
**defaults,
"queue_name": "__RQ__",
"state": Queue.State.DONE,
}
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)