@ -73,6 +73,7 @@ TENANT_APPS = [
|
|||||||
"django.contrib.auth",
|
"django.contrib.auth",
|
||||||
"django.contrib.contenttypes",
|
"django.contrib.contenttypes",
|
||||||
"django.contrib.sessions",
|
"django.contrib.sessions",
|
||||||
|
"pgtrigger",
|
||||||
"authentik.admin",
|
"authentik.admin",
|
||||||
"authentik.api",
|
"authentik.api",
|
||||||
"authentik.crypto",
|
"authentik.crypto",
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
from enum import Enum, StrEnum, auto
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
@ -22,20 +21,15 @@ from psycopg import Notify
|
|||||||
from psycopg.errors import AdminShutdown
|
from psycopg.errors import AdminShutdown
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.tasks.models import Queue as MQueue
|
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
|
||||||
from authentik.tasks.results import PostgresBackend
|
from authentik.tasks.results import PostgresBackend
|
||||||
from authentik.tenants.utils import get_current_tenant
|
from authentik.tenants.utils import get_current_tenant
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class ChannelIdentifier(StrEnum):
|
|
||||||
ENQUEUE = auto()
|
|
||||||
LOCK = auto()
|
|
||||||
|
|
||||||
|
|
||||||
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||||
return f"authentik.tasks.{queue_name}.{identifier.value}"
|
return f"{CHANNEL_PREFIX}.{queue_name}.{identifier.value}"
|
||||||
|
|
||||||
|
|
||||||
def raise_connection_error(func):
|
def raise_connection_error(func):
|
||||||
@ -62,7 +56,11 @@ class PostgresBroker(Broker):
|
|||||||
self.add_middleware(Results(backend=self.backend))
|
self.add_middleware(Results(backend=self.backend))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def consumer_class(self):
|
def connection(self) -> DatabaseWrapper:
|
||||||
|
return connections[self.db_alias]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def consumer_class(self) -> "type[_PostgresConsumer]":
|
||||||
return _PostgresConsumer
|
return _PostgresConsumer
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -120,14 +118,24 @@ class PostgresBroker(Broker):
|
|||||||
self.declare_queue(canonical_queue_name)
|
self.declare_queue(canonical_queue_name)
|
||||||
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
||||||
self.emit_before("enqueue", message, delay)
|
self.emit_before("enqueue", message, delay)
|
||||||
# TODO: notify
|
encoded = message.encode()
|
||||||
# TODO: update_or_create
|
query = {
|
||||||
self.query_set.create(
|
"message_id": message.message_id,
|
||||||
message_id=message.message_id,
|
}
|
||||||
tenant=get_current_tenant(),
|
defaults = {
|
||||||
queue_name=message.queue_name,
|
"tenant": get_current_tenant(),
|
||||||
state=MQueue.State.QUEUED,
|
"queue_name": message.queue_name,
|
||||||
message=message.encode(),
|
"state": MQueue.State.QUEUED,
|
||||||
|
"message": encoded,
|
||||||
|
}
|
||||||
|
create_defaults = {
|
||||||
|
**query,
|
||||||
|
**defaults,
|
||||||
|
}
|
||||||
|
self.query_set.update_or_create(
|
||||||
|
**query,
|
||||||
|
defaults=defaults,
|
||||||
|
create_defaults=create_defaults,
|
||||||
)
|
)
|
||||||
self.emit_after("enqueue", message, delay)
|
self.emit_after("enqueue", message, delay)
|
||||||
return message
|
return message
|
||||||
@ -227,7 +235,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
queue_name=message.queue_name,
|
queue_name=message.queue_name,
|
||||||
state__ne=MQueue.State.REJECTED,
|
state__ne=MQueue.State.REJECTED,
|
||||||
).update(
|
).update(
|
||||||
state=MQueue.State.REJECT,
|
state=MQueue.State.REJECTED,
|
||||||
message=message.encode(),
|
message=message.encode(),
|
||||||
)
|
)
|
||||||
self.in_processing.remove(message.message_id)
|
self.in_processing.remove(message.message_id)
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
# Generated by Django 5.0.12 on 2025-03-08 22:39
|
# Generated by Django 5.0.12 on 2025-03-09 00:47
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
import django.utils.timezone
|
import django.utils.timezone
|
||||||
|
import pgtrigger.compiler
|
||||||
|
import pgtrigger.migrations
|
||||||
import uuid
|
import uuid
|
||||||
from django.db import migrations, models
|
from django.db import migrations, models
|
||||||
|
|
||||||
@ -9,7 +12,9 @@ class Migration(migrations.Migration):
|
|||||||
|
|
||||||
initial = True
|
initial = True
|
||||||
|
|
||||||
dependencies = []
|
dependencies = [
|
||||||
|
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
|
||||||
|
]
|
||||||
|
|
||||||
operations = [
|
operations = [
|
||||||
migrations.CreateModel(
|
migrations.CreateModel(
|
||||||
@ -38,6 +43,12 @@ class Migration(migrations.Migration):
|
|||||||
("message", models.JSONField(blank=True, null=True)),
|
("message", models.JSONField(blank=True, null=True)),
|
||||||
("result", models.JSONField(blank=True, null=True)),
|
("result", models.JSONField(blank=True, null=True)),
|
||||||
("result_ttl", models.DateTimeField(blank=True, null=True)),
|
("result_ttl", models.DateTimeField(blank=True, null=True)),
|
||||||
|
(
|
||||||
|
"tenant",
|
||||||
|
models.ForeignKey(
|
||||||
|
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
|
||||||
|
),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
options={
|
options={
|
||||||
"indexes": [
|
"indexes": [
|
||||||
@ -48,4 +59,21 @@ class Migration(migrations.Migration):
|
|||||||
migrations.RunSQL(
|
migrations.RunSQL(
|
||||||
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
|
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
|
||||||
),
|
),
|
||||||
|
pgtrigger.migrations.AddTrigger(
|
||||||
|
model_name="queue",
|
||||||
|
trigger=pgtrigger.compiler.Trigger(
|
||||||
|
name="notify_enqueueing",
|
||||||
|
sql=pgtrigger.compiler.UpsertTriggerSql(
|
||||||
|
condition="WHEN (NEW.\"state\" = 'queued')",
|
||||||
|
constraint="CONSTRAINT",
|
||||||
|
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
|
||||||
|
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
|
||||||
|
operation="INSERT OR UPDATE",
|
||||||
|
pgid="pgtrigger_notify_enqueueing_b1977",
|
||||||
|
table="authentik_tasks_queue",
|
||||||
|
timing="DEFERRABLE INITIALLY DEFERRED",
|
||||||
|
when="AFTER",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
@ -1,10 +1,19 @@
|
|||||||
|
from enum import StrEnum, auto
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
import pgtrigger
|
||||||
|
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
|
||||||
|
CHANNEL_PREFIX = "authentik.tasks"
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelIdentifier(StrEnum):
|
||||||
|
ENQUEUE = auto()
|
||||||
|
LOCK = auto()
|
||||||
|
|
||||||
|
|
||||||
class Queue(models.Model):
|
class Queue(models.Model):
|
||||||
class State(models.TextChoices):
|
class State(models.TextChoices):
|
||||||
@ -24,6 +33,25 @@ class Queue(models.Model):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
indexes = (models.Index(fields=("state", "mtime")),)
|
indexes = (models.Index(fields=("state", "mtime")),)
|
||||||
|
triggers = (
|
||||||
|
pgtrigger.Trigger(
|
||||||
|
name="notify_enqueueing",
|
||||||
|
operation=pgtrigger.Insert | pgtrigger.Update,
|
||||||
|
when=pgtrigger.After,
|
||||||
|
condition=pgtrigger.Q(new__state="queued"),
|
||||||
|
timing=pgtrigger.Deferred,
|
||||||
|
func=f"""
|
||||||
|
PERFORM pg_notify(
|
||||||
|
'{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
|
||||||
|
CASE WHEN octet_length(NEW.message::text) >= 8000
|
||||||
|
THEN jsonb_build_object('message_id', NEW.message_id)::text
|
||||||
|
ELSE message::text
|
||||||
|
END
|
||||||
|
);
|
||||||
|
RETURN NEW;
|
||||||
|
""", # noqa: E501
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.message_id)
|
return str(self.message_id)
|
||||||
|
@ -29,10 +29,19 @@ class PostgresBackend(ResultBackend):
|
|||||||
return self.encoder.decode(data)
|
return self.encoder.decode(data)
|
||||||
|
|
||||||
def _store(self, message_key: str, result: Result, ttl: int) -> None:
|
def _store(self, message_key: str, result: Result, ttl: int) -> None:
|
||||||
# TODO: update_or_create
|
|
||||||
encoder = get_encoder()
|
encoder = get_encoder()
|
||||||
self.query_set.filter(message_id=message_key).update(
|
query = {
|
||||||
mtime=timezone.now(),
|
"message_id": message_key,
|
||||||
result=encoder.encode(result),
|
}
|
||||||
result_ttl=timezone.now() + timezone.timedelta(milliseconds=ttl),
|
defaults = {
|
||||||
)
|
"mtime": timezone.now(),
|
||||||
|
"result": encoder.encode(result),
|
||||||
|
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
|
||||||
|
}
|
||||||
|
create_defaults = {
|
||||||
|
**query,
|
||||||
|
**defaults,
|
||||||
|
"queue_name": "__RQ__",
|
||||||
|
"state": Queue.State.DONE,
|
||||||
|
}
|
||||||
|
self.query_set.update_or_create(**query, defaults=defaults, create_defaults=create_defaults)
|
||||||
|
17
poetry.lock
generated
17
poetry.lock
generated
@ -1534,6 +1534,21 @@ files = [
|
|||||||
django = ">=4"
|
django = ">=4"
|
||||||
django-pgactivity = ">=1.2,<2"
|
django-pgactivity = ">=1.2,<2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "django-pgtrigger"
|
||||||
|
version = "4.13.3"
|
||||||
|
description = "Postgres trigger support integrated with Django models."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = "<4,>=3.9.0"
|
||||||
|
files = [
|
||||||
|
{file = "django_pgtrigger-4.13.3-py3-none-any.whl", hash = "sha256:d6e4d17021bbd5e425a308f07414b237b9b34423275d86ad756b90c307df3ca4"},
|
||||||
|
{file = "django_pgtrigger-4.13.3.tar.gz", hash = "sha256:c525f9e81f120d166c4bd5fe8c3770640356f0644edf0fc2b7f6426008e52f77"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
django = ">=4"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "django-prometheus"
|
name = "django-prometheus"
|
||||||
version = "2.3.1"
|
version = "2.3.1"
|
||||||
@ -6341,4 +6356,4 @@ files = [
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "~3.12"
|
python-versions = "~3.12"
|
||||||
content-hash = "2c4a0e7ad08377212eb8ef2c3a372a557a407c43f2e2b3429403598477106c55"
|
content-hash = "1f8cfdb90051911b2d677a9fd9ad4de8afea6fd8644de4bc25d418652506b7e2"
|
||||||
|
@ -124,6 +124,7 @@ django-filter = "*"
|
|||||||
django-guardian = "*"
|
django-guardian = "*"
|
||||||
django-model-utils = "*"
|
django-model-utils = "*"
|
||||||
django-pglock = "*"
|
django-pglock = "*"
|
||||||
|
django-pgtrigger = "*"
|
||||||
django-prometheus = "*"
|
django-prometheus = "*"
|
||||||
django-redis = "*"
|
django-redis = "*"
|
||||||
django-storages = { extras = ["s3"], version = "*" }
|
django-storages = { extras = ["s3"], version = "*" }
|
||||||
|
Reference in New Issue
Block a user