Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-11 17:13:07 +01:00
parent ae211226ef
commit 3ddc35cddc
5 changed files with 29 additions and 13 deletions

View File

@ -259,9 +259,9 @@ class _PostgresConsumer(Consumer):
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
notifies = self.query_set.filter(
state__in=(TaskState.QUEUED, TaskState.CONSUMED), queue_name=self.queue_name
)
).values_list("message_id", flat=True)
channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
return [Notify(pid=0, channel=channel, payload=item.message) for item in notifies]
return [Notify(pid=0, channel=channel, payload=item) for item in notifies]
def _poll_for_notify(self):
with self.listen_connection.cursor() as cursor:

View File

@ -1,4 +1,4 @@
# Generated by Django 5.0.12 on 2025-03-09 12:37
# Generated by Django 5.0.12 on 2025-03-11 16:08
import django.db.models.deletion
import django.utils.timezone
@ -41,8 +41,8 @@ class Migration(migrations.Migration):
),
),
("mtime", models.DateTimeField(default=django.utils.timezone.now, editable=False)),
("message", models.JSONField(editable=False, null=True)),
("result", models.JSONField(editable=False, null=True)),
("message", models.TextField(editable=False)),
("result", models.TextField(editable=False, null=True)),
("result_ttl", models.DateTimeField(editable=False, null=True)),
("description", models.TextField(blank=True)),
("messages", models.JSONField(blank=True, editable=False, null=True)),
@ -61,9 +61,6 @@ class Migration(migrations.Migration):
],
},
),
migrations.RunSQL(
"ALTER TABLE authentik_tasks_task SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="task",
trigger=pgtrigger.compiler.Trigger(
@ -71,8 +68,8 @@ class Migration(migrations.Migration):
sql=pgtrigger.compiler.UpsertTriggerSql(
condition="WHEN (NEW.\"state\" = 'queued')",
constraint="CONSTRAINT",
func="\n PERFORM pg_notify(\n 'authentik.tasks' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE NEW.message::text\n END\n );\n RETURN NEW;\n ",
hash="37184bcf29160694f794426a0246c3c1a5e8e702",
func="\n PERFORM pg_notify(\n 'authentik.tasks.' || NEW.queue_name || '.enqueue',\n CASE WHEN octet_length(NEW.message::text) >= 8000\n THEN jsonb_build_object('message_id', NEW.message_id)::text\n ELSE NEW.message::text\n END\n );\n RETURN NEW;\n ",
hash="97159b94da81ceb034d235647ea771897a769f50",
operation="INSERT OR UPDATE",
pgid="pgtrigger_notify_enqueueing_0bc94",
table="authentik_tasks_task",

View File

@ -0,0 +1,18 @@
# Generated by Django 5.0.12 on 2025-03-11 16:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_tasks", "0001_initial"),
]
operations = [
migrations.AlterField(
model_name="task",
name="message",
field=models.TextField(editable=False, null=True),
),
]

View File

@ -30,9 +30,9 @@ class Task(SerializerModel):
tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, editable=False)
state = models.CharField(default=TaskState.QUEUED, choices=TaskState.choices, editable=False)
mtime = models.DateTimeField(default=timezone.now, editable=False)
message = models.JSONField(null=True, editable=False)
message = models.TextField(null=True, editable=False)
result = models.JSONField(null=True, editable=False)
result = models.TextField(null=True, editable=False)
result_ttl = models.DateTimeField(null=True, editable=False)
description = models.TextField(blank=True)
@ -48,7 +48,7 @@ class Task(SerializerModel):
condition=pgtrigger.Q(new__state=TaskState.QUEUED),
timing=pgtrigger.Deferred,
func=f"""
SELECT pg_notify(
PERFORM pg_notify(
'{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
CASE WHEN octet_length(NEW.message::text) >= 8000
THEN jsonb_build_object('message_id', NEW.message_id)::text

View File

@ -38,6 +38,7 @@ class PostgresBackend(ResultBackend):
"result": encoder.encode(result),
"result_ttl": timezone.now() + timezone.timedelta(milliseconds=ttl),
}
# TODO: tenant
create_defaults = {
**query,
**defaults,