Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-10 17:20:15 +01:00
parent 6662611347
commit ae211226ef
11 changed files with 162 additions and 61 deletions

View File

@ -1,47 +0,0 @@
"""Run worker"""
from sys import exit as sysexit
from tempfile import tempdir
from celery.apps.worker import Worker
from django.core.management.base import BaseCommand
from django.db import close_old_connections
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
from authentik.lib.debug import start_debug_server
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
class Command(BaseCommand):
    """Run worker"""

    def add_arguments(self, parser):
        """Register the ``-b``/``--beat`` switch on the command's argument parser.

        The switch uses ``store_false``: ``beat`` is ``True`` unless the flag is
        passed, in which case the worker does not run Beat (scheduled) tasks.
        """
        parser.add_argument(
            "-b",
            "--beat",
            action="store_false",
            help="When set, this worker will _not_ run Beat (scheduled) tasks",
        )

    def handle(self, **options):
        """Start a Celery worker and exit the process with the worker's exit code.

        Closes stale database connections, starts the debug server, builds the
        worker from ``CELERY_APP``, logs every registered task and then blocks
        in ``worker.start()``.
        """
        # ``tempfile.tempdir`` is ``None`` until ``gettempdir()`` has been called
        # at least once, so interpolating the bare module attribute could yield
        # the path "None/celerybeat-schedule". Resolve the directory explicitly.
        from tempfile import gettempdir

        LOGGER.debug("Celery options", **options)
        close_old_connections()
        start_debug_server()
        worker: Worker = CELERY_APP.Worker(
            no_color=False,
            quiet=True,
            optimization="fair",
            autoscale=(CONFIG.get_int("worker.concurrency"), 1),
            task_events=True,
            # Falls back to True when the option is absent from ``options``.
            beat=options.get("beat", True),
            schedule_filename=f"{gettempdir()}/celerybeat-schedule",
            queues=["authentik", "authentik_scheduled", "authentik_events"],
        )
        for task in CELERY_APP.tasks:
            LOGGER.debug("Registered task", task=task)
        worker.start()
        sysexit(worker.exitcode)

View File

@ -17,7 +17,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
dramatiq.set_encoder(JSONPickleEncoder())
broker = PostgresBroker()
broker.add_middleware(Prometheus())
# broker.add_middleware(Prometheus())
broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
broker.add_middleware(TimeLimit())
broker.add_middleware(Callbacks())

View File

@ -1,3 +1,4 @@
from psycopg import sql
import functools
import logging
import time
@ -118,7 +119,7 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
self.emit_before("enqueue", message, delay)
encoded = message.encode()
encoded = message.encode().decode()
query = {
"message_id": message.message_id,
}
@ -212,7 +213,12 @@ class _PostgresConsumer(Consumer):
# Should be set to True by Django by default
self._listen_connection.set_autocommit(True)
with self._listen_connection.cursor() as cursor:
cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
cursor.execute(
sql.SQL("LISTEN {}").format(
sql.Identifier(channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
)
)
return self._listen_connection
@raise_connection_error
def ack(self, message: Message):
@ -223,7 +229,7 @@ class _PostgresConsumer(Consumer):
state=TaskState.CONSUMED,
).update(
state=TaskState.DONE,
message=message.encode(),
message=message.encode().decode(),
)
self.in_processing.remove(message.message_id)
@ -236,7 +242,7 @@ class _PostgresConsumer(Consumer):
state__ne=TaskState.REJECTED,
).update(
state=TaskState.REJECTED,
message=message.encode(),
message=message.encode().decode(),
)
self.in_processing.remove(message.message_id)
@ -259,7 +265,7 @@ class _PostgresConsumer(Consumer):
def _poll_for_notify(self):
with self.listen_connection.cursor() as cursor:
notifies = cursor.notifies(timeout=self.timeout)
notifies = list(cursor.connection.notifies(timeout=self.timeout))
self.logger.debug(f"Received {len(notifies)} postgres notifies")
self.notifies += notifies
@ -278,8 +284,8 @@ class _PostgresConsumer(Consumer):
message_id=message.message_id,
state__in=(TaskState.QUEUED, TaskState.CONSUMED),
)
.update(state=TaskState.CONSUMED, mtime=timezone.now())
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
.update(state=TaskState.CONSUMED, mtime=timezone.now())
)
return result == 1
@ -334,7 +340,7 @@ class _PostgresConsumer(Consumer):
# No message to process
self._purge_locks()
self._auto_pruge()
self._auto_purge()
def _purge_locks(self):
while True:
@ -344,7 +350,9 @@ class _PostgresConsumer(Consumer):
return
self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
with self.connection.cursor() as cursor:
cursor.execute("SELECT pg_advisory_unlock(%s)", self._get_message_lock_id(message))
cursor.execute(
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
)
self.unlock_queue.task_done()
def _auto_purge(self):

View File

@ -21,11 +21,11 @@ class JSONPickleEncoder(dramatiq.encoder.Encoder):
keys=True,
warn=True,
use_base85=True,
).encode("utf-8")
).encode()
def decode(self, data: bytes) -> MessageData:
return jsonpickle.decode(
data.decode("utf-8"),
data.decode(),
backend=OrjsonBackend(),
keys=True,
on_missing="warn",

View File

View File

@ -0,0 +1,86 @@
import os
from django.utils.module_loading import module_has_submodule
from authentik.lib.utils.reflection import get_apps
import sys
from django.core.management.base import BaseCommand
class Command(BaseCommand):
    """Run worker"""

    def add_arguments(self, parser):
        """Declare the reload, gevent, process-count and thread-count options."""
        boolean_switches = (
            ("--reload", {"dest": "use_watcher", "help": "Enable autoreload"}),
            (
                "--reload-use-polling",
                {
                    "dest": "use_polling_watcher",
                    "help": "Use a poll-based file watcher for autoreload",
                },
            ),
            ("--use-gevent", {"help": "Use gevent for worker concurrency"}),
        )
        for flag, extra in boolean_switches:
            parser.add_argument(flag, action="store_true", **extra)

        integer_options = (
            (("--processes", "-p"), "The number of processes to run"),
            (("--threads", "-t"), "The number of threads per process to use"),
        )
        for names, help_text in integer_options:
            parser.add_argument(*names, default=1, type=int, help=help_text)

    def handle(
        self, use_watcher, use_polling_watcher, use_gevent, processes, threads, verbosity, **options
    ):
        """Replace the current process with the dramatiq command-line runner.

        Builds the argument vector from the parsed options and the discovered
        tasks modules, then hands over with ``os.execvp``.
        """
        if use_gevent:
            program = "dramatiq-gevent"
        else:
            program = "dramatiq"
        program_path = self._resolve_executable(program)

        # The polling flag is only meaningful when watching is enabled.
        reload_flags = []
        if use_watcher:
            reload_flags += ["--watch", "."]
            if use_polling_watcher:
                reload_flags.append("--watch-use-polling")

        # One "-v" for each verbosity level above 1.
        verbose_flags = ["-v" for _ in range(verbosity - 1)]
        task_modules = self._discover_tasks_modules()

        argv = [program, "--path", ".", "--processes", str(processes), "--threads", str(threads)]
        argv.extend(reload_flags)
        argv.extend(verbose_flags)
        argv.extend(task_modules)

        os.execvp(program_path, argv)

    def _resolve_executable(self, exec_name: str):
        """Return the path of ``exec_name`` next to the interpreter, else the bare name.

        Looks in the directory of ``sys.executable`` and in its ``Scripts``
        subdirectory; the first existing file wins.
        """
        interpreter_dir = os.path.dirname(sys.executable)
        if not interpreter_dir:
            return exec_name

        search_dirs = (interpreter_dir, os.path.join(interpreter_dir, "Scripts"))
        candidates = (os.path.join(folder, exec_name) for folder in search_dirs)
        return next((path for path in candidates if os.path.isfile(path)), exec_name)

    def _discover_tasks_modules(self) -> list[str]:
        """List the dotted module names handed to dramatiq, setup module first."""
        # Does not support a tasks directory
        modules = ["authentik.tasks.setup"]
        for app in get_apps():
            if module_has_submodule(app.module, "tasks"):
                modules.append(f"{app.name}.tasks")
        return modules

View File

@ -62,7 +62,7 @@ class Migration(migrations.Migration):
},
),
migrations.RunSQL(
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
"ALTER TABLE authentik_tasks_task SET WITHOUT OIDS;", migrations.RunSQL.noop
),
pgtrigger.migrations.AddTrigger(
model_name="task",

View File

@ -48,8 +48,8 @@ class Task(SerializerModel):
condition=pgtrigger.Q(new__state=TaskState.QUEUED),
timing=pgtrigger.Deferred,
func=f"""
PERFORM pg_notify(
'{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
SELECT pg_notify(
'{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
CASE WHEN octet_length(NEW.message::text) >= 8000
THEN jsonb_build_object('message_id', NEW.message_id)::text
ELSE NEW.message::text

45
authentik/tasks/setup.py Normal file
View File

@ -0,0 +1,45 @@
import os
import sys
import warnings
from authentik.lib.config import CONFIG
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
from lifecycle.migrate import run_migrations
from lifecycle.wait_for_db import wait_for_db
# Bootstrap sequence executed at import time. The order of the statements
# below matters: each step prepares the environment for the next one.

# Ignore specific warning messages (matched against the start of the message).
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
    "ignore",
    "defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
    "ignore",
    "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)

defuse_stdlib()

if CONFIG.get_bool("compliance.fips.enabled", False):
    backend._enable_fips()

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
wait_for_db()

# NOTE(review): a leftover debug `print(sys.argv)` was removed here; it wrote the
# raw argument vector to stdout on every import of this module.
# NOTE(review): the gate below inspects sys.argv[1]. When this module is loaded
# by an external runner rather than a management command, argv[1] is presumably
# a runner flag, so migrations would be skipped - confirm this is intended.
if (
    len(sys.argv) > 1
    # Explicitly only run migrate for server and worker
    # `bootstrap_tasks` is a special case as that command might be triggered by the `ak`
    # script to pre-run certain tasks for an automated install
    and sys.argv[1] in ["dev_server", "worker", "bootstrap_tasks"]
    # and don't run if this is the child process of a dev_server
    and os.environ.get(DJANGO_AUTORELOAD_ENV, None) is None
):
    run_migrations()

# Django is initialised last, after the settings module is configured, the
# database is reachable and (when applicable) migrations have run.
import django

django.setup()

9
authentik/tasks/tasks.py Normal file
View File

@ -0,0 +1,9 @@
from dramatiq import actor
@actor
def test_actor():
    """Sleep for five seconds, then print a completion message."""
    from time import sleep

    sleep(5)
    print("done sleeping")