@ -13,6 +13,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
|||||||
|
|
||||||
def ready(self) -> None:
|
def ready(self) -> None:
|
||||||
from authentik.tasks.broker import PostgresBroker
|
from authentik.tasks.broker import PostgresBroker
|
||||||
|
from authentik.tasks.middleware import CurrentTask
|
||||||
|
|
||||||
dramatiq.set_encoder(JSONPickleEncoder())
|
dramatiq.set_encoder(JSONPickleEncoder())
|
||||||
broker = PostgresBroker()
|
broker = PostgresBroker()
|
||||||
@ -21,5 +22,6 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
|||||||
broker.add_middleware(TimeLimit())
|
broker.add_middleware(TimeLimit())
|
||||||
broker.add_middleware(Callbacks())
|
broker.add_middleware(Callbacks())
|
||||||
broker.add_middleware(Retries(max_retries=3))
|
broker.add_middleware(Retries(max_retries=3))
|
||||||
|
broker.add_middleware(CurrentTask())
|
||||||
dramatiq.set_broker(broker)
|
dramatiq.set_broker(broker)
|
||||||
return super().ready()
|
return super().ready()
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from psycopg import Notify
|
|||||||
from psycopg.errors import AdminShutdown
|
from psycopg.errors import AdminShutdown
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.tasks.models import Queue as MQueue, CHANNEL_PREFIX, ChannelIdentifier
|
from authentik.tasks.models import Task, CHANNEL_PREFIX, ChannelIdentifier
|
||||||
from authentik.tasks.results import PostgresBackend
|
from authentik.tasks.results import PostgresBackend
|
||||||
from authentik.tenants.utils import get_current_tenant
|
from authentik.tenants.utils import get_current_tenant
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class PostgresBroker(Broker):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def query_set(self) -> QuerySet:
|
def query_set(self) -> QuerySet:
|
||||||
return MQueue.objects.using(self.db_alias)
|
return Task.objects.using(self.db_alias)
|
||||||
|
|
||||||
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
|
def consume(self, queue_name: str, prefetch: int = 1, timeout: int = 30000) -> Consumer:
|
||||||
self.declare_queue(queue_name)
|
self.declare_queue(queue_name)
|
||||||
@ -125,7 +125,7 @@ class PostgresBroker(Broker):
|
|||||||
defaults = {
|
defaults = {
|
||||||
"tenant": get_current_tenant(),
|
"tenant": get_current_tenant(),
|
||||||
"queue_name": message.queue_name,
|
"queue_name": message.queue_name,
|
||||||
"state": MQueue.State.QUEUED,
|
"state": Task.State.QUEUED,
|
||||||
"message": encoded,
|
"message": encoded,
|
||||||
}
|
}
|
||||||
create_defaults = {
|
create_defaults = {
|
||||||
@ -169,7 +169,7 @@ class PostgresBroker(Broker):
|
|||||||
if (
|
if (
|
||||||
self.query_set.filter(
|
self.query_set.filter(
|
||||||
queue_name=queue_name,
|
queue_name=queue_name,
|
||||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
|
||||||
)
|
)
|
||||||
== 0
|
== 0
|
||||||
):
|
):
|
||||||
@ -200,7 +200,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def query_set(self) -> QuerySet:
|
def query_set(self) -> QuerySet:
|
||||||
return MQueue.objects.using(self.db_alias)
|
return Task.objects.using(self.db_alias)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def listen_connection(self) -> DatabaseWrapper:
|
def listen_connection(self) -> DatabaseWrapper:
|
||||||
@ -220,9 +220,9 @@ class _PostgresConsumer(Consumer):
|
|||||||
self.query_set.filter(
|
self.query_set.filter(
|
||||||
message_id=message.message_id,
|
message_id=message.message_id,
|
||||||
queue_name=message.queue_name,
|
queue_name=message.queue_name,
|
||||||
state=MQueue.State.CONSUMED,
|
state=Task.State.CONSUMED,
|
||||||
).update(
|
).update(
|
||||||
state=MQueue.State.DONE,
|
state=Task.State.DONE,
|
||||||
message=message.encode(),
|
message=message.encode(),
|
||||||
)
|
)
|
||||||
self.in_processing.remove(message.message_id)
|
self.in_processing.remove(message.message_id)
|
||||||
@ -233,9 +233,9 @@ class _PostgresConsumer(Consumer):
|
|||||||
self.query_set.filter(
|
self.query_set.filter(
|
||||||
message_id=message.message_id,
|
message_id=message.message_id,
|
||||||
queue_name=message.queue_name,
|
queue_name=message.queue_name,
|
||||||
state__ne=MQueue.State.REJECTED,
|
state__ne=Task.State.REJECTED,
|
||||||
).update(
|
).update(
|
||||||
state=MQueue.State.REJECTED,
|
state=Task.State.REJECTED,
|
||||||
message=message.encode(),
|
message=message.encode(),
|
||||||
)
|
)
|
||||||
self.in_processing.remove(message.message_id)
|
self.in_processing.remove(message.message_id)
|
||||||
@ -245,14 +245,14 @@ class _PostgresConsumer(Consumer):
|
|||||||
self.query_set.filter(
|
self.query_set.filter(
|
||||||
message_id__in=[message.message_id for message in messages],
|
message_id__in=[message.message_id for message in messages],
|
||||||
).update(
|
).update(
|
||||||
state=MQueue.State.QUEUED,
|
state=Task.State.QUEUED,
|
||||||
)
|
)
|
||||||
# We don't care about locks, requeue occurs on worker stop
|
# We don't care about locks, requeue occurs on worker stop
|
||||||
|
|
||||||
def _fetch_pending_notifies(self) -> list[Notify]:
|
def _fetch_pending_notifies(self) -> list[Notify]:
|
||||||
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
|
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
|
||||||
notifies = self.query_set.filter(
|
notifies = self.query_set.filter(
|
||||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED), queue_name=self.queue_name
|
state__in=(Task.State.QUEUED, Task.State.CONSUMED), queue_name=self.queue_name
|
||||||
)
|
)
|
||||||
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
||||||
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
|
||||||
@ -276,9 +276,9 @@ class _PostgresConsumer(Consumer):
|
|||||||
result = (
|
result = (
|
||||||
self.query_set.filter(
|
self.query_set.filter(
|
||||||
message_id=message.message_id,
|
message_id=message.message_id,
|
||||||
state__in=(MQueue.State.QUEUED, MQueue.State.CONSUMED),
|
state__in=(Task.State.QUEUED, Task.State.CONSUMED),
|
||||||
)
|
)
|
||||||
.update(state=MQueue.State.CONSUMED, mtime=timezone.now())
|
.update(state=Task.State.CONSUMED, mtime=timezone.now())
|
||||||
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
|
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
|
||||||
)
|
)
|
||||||
return result == 1
|
return result == 1
|
||||||
@ -354,7 +354,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
return
|
return
|
||||||
self.logger.debug("Running garbage collector")
|
self.logger.debug("Running garbage collector")
|
||||||
count = self.query_set.filter(
|
count = self.query_set.filter(
|
||||||
state__in=(MQueue.State.DONE, MQueue.State.REJECTED),
|
state__in=(Task.State.DONE, Task.State.REJECTED),
|
||||||
mtime__lte=timezone.now() - timezone.timedelta(days=30),
|
mtime__lte=timezone.now() - timezone.timedelta(days=30),
|
||||||
).delete()
|
).delete()
|
||||||
self.logger.info(f"Purged {count} messages in all queues")
|
self.logger.info(f"Purged {count} messages in all queues")
|
||||||
|
|||||||
19
authentik/tasks/middleware.py
Normal file
19
authentik/tasks/middleware.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import contextvars
|
||||||
|
from dramatiq.message import Message
|
||||||
|
from dramatiq.middleware import Middleware
|
||||||
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentTask(Middleware):
|
||||||
|
_TASK: contextvars.ContextVar[Task | None] = contextvars.ContextVar("_TASK", default=None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_task(cls) -> Task | None:
|
||||||
|
return cls._TASK.get()
|
||||||
|
|
||||||
|
def before_process_message(self, _, message: Message):
|
||||||
|
self._TASK.set(Task.objects.get(message_id=message.message_id))
|
||||||
|
|
||||||
|
def after_process_message(self, *args, **kwargs):
|
||||||
|
self._TASK.get().save()
|
||||||
|
self._TASK.set(None)
|
||||||
@ -1,79 +0,0 @@
|
|||||||
# Generated by Django 5.0.12 on 2025-03-09 00:47
|
|
||||||
|
|
||||||
import django.db.models.deletion
|
|
||||||
import django.utils.timezone
|
|
||||||
import pgtrigger.compiler
|
|
||||||
import pgtrigger.migrations
|
|
||||||
import uuid
|
|
||||||
from django.db import migrations, models
|
|
||||||
|
|
||||||
|
|
||||||
class Migration(migrations.Migration):
|
|
||||||
|
|
||||||
initial = True
|
|
||||||
|
|
||||||
dependencies = [
|
|
||||||
("authentik_tenants", "0004_tenant_impersonation_require_reason"),
|
|
||||||
]
|
|
||||||
|
|
||||||
operations = [
|
|
||||||
migrations.CreateModel(
|
|
||||||
name="Queue",
|
|
||||||
fields=[
|
|
||||||
(
|
|
||||||
"message_id",
|
|
||||||
models.UUIDField(
|
|
||||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
|
||||||
),
|
|
||||||
),
|
|
||||||
("queue_name", models.TextField(default="default")),
|
|
||||||
(
|
|
||||||
"state",
|
|
||||||
models.CharField(
|
|
||||||
choices=[
|
|
||||||
("queued", "Queued"),
|
|
||||||
("consumed", "Consumed"),
|
|
||||||
("rejected", "Rejected"),
|
|
||||||
("done", "Done"),
|
|
||||||
],
|
|
||||||
default="queued",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
("mtime", models.DateTimeField(default=django.utils.timezone.now)),
|
|
||||||
("message", models.JSONField(blank=True, null=True)),
|
|
||||||
("result", models.JSONField(blank=True, null=True)),
|
|
||||||
("result_ttl", models.DateTimeField(blank=True, null=True)),
|
|
||||||
(
|
|
||||||
"tenant",
|
|
||||||
models.ForeignKey(
|
|
||||||
on_delete=django.db.models.deletion.CASCADE, to="authentik_tenants.tenant"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
options={
|
|
||||||
"indexes": [
|
|
||||||
models.Index(fields=["state", "mtime"], name="authentik_t_state_b7ff76_idx")
|
|
||||||
],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
migrations.RunSQL(
|
|
||||||
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
|
|
||||||
),
|
|
||||||
pgtrigger.migrations.AddTrigger(
|
|
||||||
model_name="queue",
|
|
||||||
trigger=pgtrigger.compiler.Trigger(
|
|
||||||
name="notify_enqueueing",
|
|
||||||
sql=pgtrigger.compiler.UpsertTriggerSql(
|
|
||||||
condition="WHEN (NEW.\"state\" = 'queued')",
|
|
||||||
constraint="CONSTRAINT",
|
|
||||||
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE message::text\n END\n );\n RETURN NEW;\n ",
|
|
||||||
hash="d604c5647f3821f100e8aa7a52be181bde9ebdce",
|
|
||||||
operation="INSERT OR UPDATE",
|
|
||||||
pgid="pgtrigger_notify_enqueueing_b1977",
|
|
||||||
table="authentik_tasks_queue",
|
|
||||||
timing="DEFERRABLE INITIALLY DEFERRED",
|
|
||||||
when="AFTER",
|
|
||||||
),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
@ -5,6 +5,7 @@ import pgtrigger
|
|||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from authentik.lib.models import SerializerModel
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
|
||||||
CHANNEL_PREFIX = "authentik.tasks"
|
CHANNEL_PREFIX = "authentik.tasks"
|
||||||
@ -15,21 +16,23 @@ class ChannelIdentifier(StrEnum):
|
|||||||
LOCK = auto()
|
LOCK = auto()
|
||||||
|
|
||||||
|
|
||||||
class Queue(models.Model):
|
class Task(SerializerModel):
|
||||||
class State(models.TextChoices):
|
class State(models.TextChoices):
|
||||||
QUEUED = "queued"
|
QUEUED = "queued"
|
||||||
CONSUMED = "consumed"
|
CONSUMED = "consumed"
|
||||||
REJECTED = "rejected"
|
REJECTED = "rejected"
|
||||||
DONE = "done"
|
DONE = "done"
|
||||||
|
|
||||||
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
|
|
||||||
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
message_id = models.UUIDField(primary_key=True, default=uuid4, editable=False)
|
||||||
queue_name = models.TextField(default="default")
|
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False)
|
||||||
state = models.CharField(default=State.QUEUED, choices=State.choices)
|
queue_name = models.TextField(default="default", editable=False)
|
||||||
mtime = models.DateTimeField(default=timezone.now)
|
state = models.CharField(default=State.QUEUED, choices=State.choices, editable=False)
|
||||||
message = models.JSONField(blank=True, null=True)
|
mtime = models.DateTimeField(default=timezone.now, editable=False)
|
||||||
result = models.JSONField(blank=True, null=True)
|
message = models.JSONField(blank=True, null=True, editable=False)
|
||||||
result_ttl = models.DateTimeField(blank=True, null=True)
|
result = models.JSONField(blank=True, null=True, editable=False)
|
||||||
|
result_ttl = models.DateTimeField(blank=True, null=True, editable=False)
|
||||||
|
description = models.TextField(blank=True)
|
||||||
|
messages = models.JSONField(blank=True, null=True, editable=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
indexes = (models.Index(fields=("state", "mtime")),)
|
indexes = (models.Index(fields=("state", "mtime")),)
|
||||||
@ -55,3 +58,8 @@ class Queue(models.Model):
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.message_id)
|
return str(self.message_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def serializer(self):
|
||||||
|
# TODO: fixme
|
||||||
|
pass
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from django.utils import timezone
|
|||||||
from dramatiq.message import Message, get_encoder
|
from dramatiq.message import Message, get_encoder
|
||||||
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
|
from dramatiq.results.backend import Missing, MResult, Result, ResultBackend
|
||||||
|
|
||||||
from authentik.tasks.models import Queue
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
|
|
||||||
class PostgresBackend(ResultBackend):
|
class PostgresBackend(ResultBackend):
|
||||||
@ -14,7 +14,7 @@ class PostgresBackend(ResultBackend):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def query_set(self) -> QuerySet:
|
def query_set(self) -> QuerySet:
|
||||||
return Queue.objects.using(self.db_alias)
|
return Task.objects.using(self.db_alias)
|
||||||
|
|
||||||
def build_message_key(self, message: Message) -> str:
|
def build_message_key(self, message: Message) -> str:
|
||||||
return str(message.message_id)
|
return str(message.message_id)
|
||||||
|
|||||||
Reference in New Issue
Block a user