diff --git a/authentik/core/management/commands/worker.py b/authentik/core/management/commands/worker.py deleted file mode 100644 index 8b3ed9346c..0000000000 --- a/authentik/core/management/commands/worker.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Run worker""" - -from sys import exit as sysexit -from tempfile import tempdir - -from celery.apps.worker import Worker -from django.core.management.base import BaseCommand -from django.db import close_old_connections -from structlog.stdlib import get_logger - -from authentik.lib.config import CONFIG -from authentik.lib.debug import start_debug_server -from authentik.root.celery import CELERY_APP - -LOGGER = get_logger() - - -class Command(BaseCommand): - """Run worker""" - - def add_arguments(self, parser): - parser.add_argument( - "-b", - "--beat", - action="store_false", - help="When set, this worker will _not_ run Beat (scheduled) tasks", - ) - - def handle(self, **options): - LOGGER.debug("Celery options", **options) - close_old_connections() - start_debug_server() - worker: Worker = CELERY_APP.Worker( - no_color=False, - quiet=True, - optimization="fair", - autoscale=(CONFIG.get_int("worker.concurrency"), 1), - task_events=True, - beat=options.get("beat", True), - schedule_filename=f"{tempdir}/celerybeat-schedule", - queues=["authentik", "authentik_scheduled", "authentik_events"], - ) - for task in CELERY_APP.tasks: - LOGGER.debug("Registered task", task=task) - - worker.start() - sysexit(worker.exitcode) diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py index 4061b0af6c..8f9ab2b034 100644 --- a/authentik/tasks/apps.py +++ b/authentik/tasks/apps.py @@ -17,7 +17,7 @@ class AuthentikTasksConfig(ManagedAppConfig): dramatiq.set_encoder(JSONPickleEncoder()) broker = PostgresBroker() - broker.add_middleware(Prometheus()) + # broker.add_middleware(Prometheus()) broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000)) broker.add_middleware(TimeLimit()) broker.add_middleware(Callbacks()) diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 2c3040cb1d..88a1a44528 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -1,3 +1,4 @@ +from psycopg import sql import functools import logging import time @@ -118,7 +119,7 @@ class PostgresBroker(Broker): self.declare_queue(canonical_queue_name) self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") self.emit_before("enqueue", message, delay) - encoded = message.encode() + encoded = message.encode().decode() query = { "message_id": message.message_id, } @@ -212,7 +213,12 @@ class _PostgresConsumer(Consumer): # Should be set to True by Django by default self._listen_connection.set_autocommit(True) with self._listen_connection.cursor() as cursor: - cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)) + cursor.execute( + sql.SQL("LISTEN {}").format( + sql.Identifier(channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)) + ) + ) + return self._listen_connection @raise_connection_error def ack(self, message: Message): @@ -223,7 +229,7 @@ class _PostgresConsumer(Consumer): state=TaskState.CONSUMED, ).update( state=TaskState.DONE, - message=message.encode(), + message=message.encode().decode(), ) self.in_processing.remove(message.message_id) @@ -236,7 +242,7 @@ class _PostgresConsumer(Consumer): state__ne=TaskState.REJECTED, ).update( state=TaskState.REJECTED, - message=message.encode(), + message=message.encode().decode(), ) self.in_processing.remove(message.message_id) @@ -259,7 +265,7 @@ class _PostgresConsumer(Consumer): def _poll_for_notify(self): with self.listen_connection.cursor() as cursor: - notifies = cursor.notifies(timeout=self.timeout) + notifies = list(cursor.connection.notifies(timeout=self.timeout)) self.logger.debug(f"Received {len(notifies)} postgres notifies") self.notifies += notifies @@ -278,8 +284,8 @@ class _PostgresConsumer(Consumer): message_id=message.message_id, state__in=(TaskState.QUEUED, TaskState.CONSUMED), ) - .update(state=TaskState.CONSUMED, mtime=timezone.now()) .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)]) + .update(state=TaskState.CONSUMED, mtime=timezone.now()) ) return result == 1 @@ -334,7 +340,7 @@ class _PostgresConsumer(Consumer): # No message to process self._purge_locks() - self._auto_pruge() + self._auto_purge() def _purge_locks(self): while True: @@ -344,7 +350,9 @@ class _PostgresConsumer(Consumer): return self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}") with self.connection.cursor() as cursor: - cursor.execute("SELECT pg_advisory_unlock(%s)", self._get_message_lock_id(message)) + cursor.execute( + "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),) + ) self.unlock_queue.task_done() def _auto_purge(self): diff --git a/authentik/tasks/encoder.py b/authentik/tasks/encoder.py index 833b625d4b..c6fb399fe6 100644 --- a/authentik/tasks/encoder.py +++ b/authentik/tasks/encoder.py @@ -21,11 +21,11 @@ class JSONPickleEncoder(dramatiq.encoder.Encoder): keys=True, warn=True, use_base85=True, - ).encode("utf-8") + ).encode() def decode(self, data: bytes) -> MessageData: return jsonpickle.decode( - data.decode("utf-8"), + data.decode(), backend=OrjsonBackend(), keys=True, on_missing="warn", diff --git a/authentik/tasks/management/__init__.py b/authentik/tasks/management/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/management/commands/__init__.py b/authentik/tasks/management/commands/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/authentik/tasks/management/commands/worker.py b/authentik/tasks/management/commands/worker.py new file mode 100644 index 0000000000..93b48bf356 --- /dev/null +++ b/authentik/tasks/management/commands/worker.py @@ -0,0 +1,86 @@ +import os + +from django.utils.module_loading import module_has_submodule +from authentik.lib.utils.reflection import get_apps +import sys + +from django.core.management.base import BaseCommand + + +class Command(BaseCommand): + """Run worker""" + + def add_arguments(self, parser): + parser.add_argument( + "--reload", + action="store_true", + dest="use_watcher", + help="Enable autoreload", + ) + parser.add_argument( + "--reload-use-polling", + action="store_true", + dest="use_polling_watcher", + help="Use a poll-based file watcher for autoreload", + ) + parser.add_argument( + "--use-gevent", + action="store_true", + help="Use gevent for worker concurrency", + ) + parser.add_argument( + "--processes", + "-p", + default=1, + type=int, + help="The number of processes to run", + ) + parser.add_argument( + "--threads", + "-t", + default=1, + type=int, + help="The number of threads per process to use", + ) + + def handle( + self, use_watcher, use_polling_watcher, use_gevent, processes, threads, verbosity, **options + ): + executable_name = "dramatiq-gevent" if use_gevent else "dramatiq" + executable_path = self._resolve_executable(executable_name) + watch_args = ["--watch", "."] if use_watcher else [] + if watch_args and use_polling_watcher: + watch_args.append("--watch-use-polling") + + verbosity_args = ["-v"] * (verbosity - 1) + + tasks_modules = self._discover_tasks_modules() + process_args = [ + executable_name, + "--path", + ".", + "--processes", + str(processes), + "--threads", + str(threads), + *watch_args, + *verbosity_args, + *tasks_modules, + ] + + os.execvp(executable_path, process_args) + + def _resolve_executable(self, exec_name: str): + bin_dir = os.path.dirname(sys.executable) + if bin_dir: + for d in [bin_dir, os.path.join(bin_dir, "Scripts")]: + exec_path = os.path.join(d, exec_name) + if os.path.isfile(exec_path): + return exec_path + return exec_name + + def _discover_tasks_modules(self) -> list[str]: + # Does not support a tasks directory + return ["authentik.tasks.setup"] + [ + f"{app.name}.tasks" for app in get_apps() if module_has_submodule(app.module, "tasks") + ] diff --git a/authentik/tasks/migrations/0001_initial.py b/authentik/tasks/migrations/0001_initial.py index 2d68423bf7..cd8e300ac9 100644 --- a/authentik/tasks/migrations/0001_initial.py +++ b/authentik/tasks/migrations/0001_initial.py @@ -62,7 +62,7 @@ class Migration(migrations.Migration): }, ), migrations.RunSQL( - "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop + "ALTER TABLE authentik_tasks_task SET WITHOUT OIDS;", migrations.RunSQL.noop ), pgtrigger.migrations.AddTrigger( model_name="task", diff --git a/authentik/tasks/models.py b/authentik/tasks/models.py index d5a57e0b0d..c8cd969fd0 100644 --- a/authentik/tasks/models.py +++ b/authentik/tasks/models.py @@ -48,8 +48,8 @@ class Task(SerializerModel): condition=pgtrigger.Q(new__state=TaskState.QUEUED), timing=pgtrigger.Deferred, func=f""" - PERFORM pg_notify( - '{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}', + SELECT pg_notify( + '{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}', CASE WHEN octet_length(NEW.message::text) >= 8000 THEN jsonb_build_object('message_id', NEW.message_id)::text ELSE NEW.message::text diff --git a/authentik/tasks/setup.py b/authentik/tasks/setup.py new file mode 100644 index 0000000000..92b9fb83bb --- /dev/null +++ b/authentik/tasks/setup.py @@ -0,0 +1,45 @@ +import os +import sys +import warnings + +from authentik.lib.config import CONFIG +from cryptography.hazmat.backends.openssl.backend import backend +from defusedxml import defuse_stdlib +from django.utils.autoreload import DJANGO_AUTORELOAD_ENV + +from lifecycle.migrate import run_migrations +from lifecycle.wait_for_db import wait_for_db + +warnings.filterwarnings("ignore", "SelectableGroups dict interface") +warnings.filterwarnings( + "ignore", + "defusedxml.lxml is no longer supported and will be removed in a future release.", +) +warnings.filterwarnings( + "ignore", + "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.", +) + +defuse_stdlib() + +if CONFIG.get_bool("compliance.fips.enabled", False): + backend._enable_fips() + + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") +wait_for_db() +print(sys.argv) +if ( + len(sys.argv) > 1 + # Explicitly only run migrate for server and worker + # `bootstrap_tasks` is a special case as that command might be triggered by the `ak` + # script to pre-run certain tasks for an automated install + and sys.argv[1] in ["dev_server", "worker", "bootstrap_tasks"] + # and don't run if this is the child process of a dev_server + and os.environ.get(DJANGO_AUTORELOAD_ENV, None) is None +): + run_migrations() + +import django + +django.setup() diff --git a/authentik/tasks/tasks.py b/authentik/tasks/tasks.py new file mode 100644 index 0000000000..9b2896b426 --- /dev/null +++ b/authentik/tasks/tasks.py @@ -0,0 +1,9 @@ +from dramatiq import actor + + +@actor +def test_actor(): + import time + + time.sleep(5) + print("done sleeping")