@ -1,47 +0,0 @@
|
||||
"""Run worker"""
|
||||
|
||||
from sys import exit as sysexit
|
||||
from tempfile import tempdir
|
||||
|
||||
from celery.apps.worker import Worker
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import close_old_connections
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.lib.debug import start_debug_server
|
||||
from authentik.root.celery import CELERY_APP
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Run worker"""
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--beat",
|
||||
action="store_false",
|
||||
help="When set, this worker will _not_ run Beat (scheduled) tasks",
|
||||
)
|
||||
|
||||
def handle(self, **options):
|
||||
LOGGER.debug("Celery options", **options)
|
||||
close_old_connections()
|
||||
start_debug_server()
|
||||
worker: Worker = CELERY_APP.Worker(
|
||||
no_color=False,
|
||||
quiet=True,
|
||||
optimization="fair",
|
||||
autoscale=(CONFIG.get_int("worker.concurrency"), 1),
|
||||
task_events=True,
|
||||
beat=options.get("beat", True),
|
||||
schedule_filename=f"{tempdir}/celerybeat-schedule",
|
||||
queues=["authentik", "authentik_scheduled", "authentik_events"],
|
||||
)
|
||||
for task in CELERY_APP.tasks:
|
||||
LOGGER.debug("Registered task", task=task)
|
||||
|
||||
worker.start()
|
||||
sysexit(worker.exitcode)
|
@ -17,7 +17,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
||||
|
||||
dramatiq.set_encoder(JSONPickleEncoder())
|
||||
broker = PostgresBroker()
|
||||
broker.add_middleware(Prometheus())
|
||||
# broker.add_middleware(Prometheus())
|
||||
broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
|
||||
broker.add_middleware(TimeLimit())
|
||||
broker.add_middleware(Callbacks())
|
||||
|
@ -1,3 +1,4 @@
|
||||
from psycopg import sql
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
@ -118,7 +119,7 @@ class PostgresBroker(Broker):
|
||||
self.declare_queue(canonical_queue_name)
|
||||
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
||||
self.emit_before("enqueue", message, delay)
|
||||
encoded = message.encode()
|
||||
encoded = message.encode().decode()
|
||||
query = {
|
||||
"message_id": message.message_id,
|
||||
}
|
||||
@ -212,7 +213,12 @@ class _PostgresConsumer(Consumer):
|
||||
# Should be set to True by Django by default
|
||||
self._listen_connection.set_autocommit(True)
|
||||
with self._listen_connection.cursor() as cursor:
|
||||
cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
|
||||
cursor.execute(
|
||||
sql.SQL("LISTEN {}").format(
|
||||
sql.Identifier(channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
|
||||
)
|
||||
)
|
||||
return self._listen_connection
|
||||
|
||||
@raise_connection_error
|
||||
def ack(self, message: Message):
|
||||
@ -223,7 +229,7 @@ class _PostgresConsumer(Consumer):
|
||||
state=TaskState.CONSUMED,
|
||||
).update(
|
||||
state=TaskState.DONE,
|
||||
message=message.encode(),
|
||||
message=message.encode().decode(),
|
||||
)
|
||||
self.in_processing.remove(message.message_id)
|
||||
|
||||
@ -236,7 +242,7 @@ class _PostgresConsumer(Consumer):
|
||||
state__ne=TaskState.REJECTED,
|
||||
).update(
|
||||
state=TaskState.REJECTED,
|
||||
message=message.encode(),
|
||||
message=message.encode().decode(),
|
||||
)
|
||||
self.in_processing.remove(message.message_id)
|
||||
|
||||
@ -259,7 +265,7 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
def _poll_for_notify(self):
|
||||
with self.listen_connection.cursor() as cursor:
|
||||
notifies = cursor.notifies(timeout=self.timeout)
|
||||
notifies = list(cursor.connection.notifies(timeout=self.timeout))
|
||||
self.logger.debug(f"Received {len(notifies)} postgres notifies")
|
||||
self.notifies += notifies
|
||||
|
||||
@ -278,8 +284,8 @@ class _PostgresConsumer(Consumer):
|
||||
message_id=message.message_id,
|
||||
state__in=(TaskState.QUEUED, TaskState.CONSUMED),
|
||||
)
|
||||
.update(state=TaskState.CONSUMED, mtime=timezone.now())
|
||||
.extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
|
||||
.update(state=TaskState.CONSUMED, mtime=timezone.now())
|
||||
)
|
||||
return result == 1
|
||||
|
||||
@ -334,7 +340,7 @@ class _PostgresConsumer(Consumer):
|
||||
|
||||
# No message to process
|
||||
self._purge_locks()
|
||||
self._auto_pruge()
|
||||
self._auto_purge()
|
||||
|
||||
def _purge_locks(self):
|
||||
while True:
|
||||
@ -344,7 +350,9 @@ class _PostgresConsumer(Consumer):
|
||||
return
|
||||
self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("SELECT pg_advisory_unlock(%s)", self._get_message_lock_id(message))
|
||||
cursor.execute(
|
||||
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
|
||||
)
|
||||
self.unlock_queue.task_done()
|
||||
|
||||
def _auto_purge(self):
|
||||
|
@ -21,11 +21,11 @@ class JSONPickleEncoder(dramatiq.encoder.Encoder):
|
||||
keys=True,
|
||||
warn=True,
|
||||
use_base85=True,
|
||||
).encode("utf-8")
|
||||
).encode()
|
||||
|
||||
def decode(self, data: bytes) -> MessageData:
|
||||
return jsonpickle.decode(
|
||||
data.decode("utf-8"),
|
||||
data.decode(),
|
||||
backend=OrjsonBackend(),
|
||||
keys=True,
|
||||
on_missing="warn",
|
||||
|
0
authentik/tasks/management/__init__.py
Normal file
0
authentik/tasks/management/__init__.py
Normal file
0
authentik/tasks/management/commands/__init__.py
Normal file
0
authentik/tasks/management/commands/__init__.py
Normal file
86
authentik/tasks/management/commands/worker.py
Normal file
86
authentik/tasks/management/commands/worker.py
Normal file
@ -0,0 +1,86 @@
|
||||
import os
|
||||
|
||||
from django.utils.module_loading import module_has_submodule
|
||||
from authentik.lib.utils.reflection import get_apps
|
||||
import sys
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""Run worker"""
|
||||
|
||||
def add_arguments(self, parser):
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
dest="use_watcher",
|
||||
help="Enable autoreload",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reload-use-polling",
|
||||
action="store_true",
|
||||
dest="use_polling_watcher",
|
||||
help="Use a poll-based file watcher for autoreload",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-gevent",
|
||||
action="store_true",
|
||||
help="Use gevent for worker concurrency",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
"-p",
|
||||
default=1,
|
||||
type=int,
|
||||
help="The number of processes to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threads",
|
||||
"-t",
|
||||
default=1,
|
||||
type=int,
|
||||
help="The number of threads per process to use",
|
||||
)
|
||||
|
||||
def handle(
|
||||
self, use_watcher, use_polling_watcher, use_gevent, processes, threads, verbosity, **options
|
||||
):
|
||||
executable_name = "dramatiq-gevent" if use_gevent else "dramatiq"
|
||||
executable_path = self._resolve_executable(executable_name)
|
||||
watch_args = ["--watch", "."] if use_watcher else []
|
||||
if watch_args and use_polling_watcher:
|
||||
watch_args.append("--watch-use-polling")
|
||||
|
||||
verbosity_args = ["-v"] * (verbosity - 1)
|
||||
|
||||
tasks_modules = self._discover_tasks_modules()
|
||||
process_args = [
|
||||
executable_name,
|
||||
"--path",
|
||||
".",
|
||||
"--processes",
|
||||
str(processes),
|
||||
"--threads",
|
||||
str(threads),
|
||||
*watch_args,
|
||||
*verbosity_args,
|
||||
*tasks_modules,
|
||||
]
|
||||
|
||||
os.execvp(executable_path, process_args)
|
||||
|
||||
def _resolve_executable(self, exec_name: str):
|
||||
bin_dir = os.path.dirname(sys.executable)
|
||||
if bin_dir:
|
||||
for d in [bin_dir, os.path.join(bin_dir, "Scripts")]:
|
||||
exec_path = os.path.join(d, exec_name)
|
||||
if os.path.isfile(exec_path):
|
||||
return exec_path
|
||||
return exec_name
|
||||
|
||||
def _discover_tasks_modules(self) -> list[str]:
|
||||
# Does not support a tasks directory
|
||||
return ["authentik.tasks.setup"] + [
|
||||
f"{app.name}.tasks" for app in get_apps() if module_has_submodule(app.module, "tasks")
|
||||
]
|
@ -62,7 +62,7 @@ class Migration(migrations.Migration):
|
||||
},
|
||||
),
|
||||
migrations.RunSQL(
|
||||
"ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
|
||||
"ALTER TABLE authentik_tasks_task SET WITHOUT OIDS;", migrations.RunSQL.noop
|
||||
),
|
||||
pgtrigger.migrations.AddTrigger(
|
||||
model_name="task",
|
||||
|
@ -48,8 +48,8 @@ class Task(SerializerModel):
|
||||
condition=pgtrigger.Q(new__state=TaskState.QUEUED),
|
||||
timing=pgtrigger.Deferred,
|
||||
func=f"""
|
||||
PERFORM pg_notify(
|
||||
'{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
|
||||
SELECT pg_notify(
|
||||
'{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
|
||||
CASE WHEN octet_length(NEW.message::text) >= 8000
|
||||
THEN jsonb_build_object('message_id', NEW.message_id)::text
|
||||
ELSE NEW.message::text
|
||||
|
45
authentik/tasks/setup.py
Normal file
45
authentik/tasks/setup.py
Normal file
@ -0,0 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from authentik.lib.config import CONFIG
|
||||
from cryptography.hazmat.backends.openssl.backend import backend
|
||||
from defusedxml import defuse_stdlib
|
||||
from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
|
||||
|
||||
from lifecycle.migrate import run_migrations
|
||||
from lifecycle.wait_for_db import wait_for_db
|
||||
|
||||
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"defusedxml.lxml is no longer supported and will be removed in a future release.",
|
||||
)
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
|
||||
)
|
||||
|
||||
defuse_stdlib()
|
||||
|
||||
if CONFIG.get_bool("compliance.fips.enabled", False):
|
||||
backend._enable_fips()
|
||||
|
||||
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
|
||||
wait_for_db()
|
||||
print(sys.argv)
|
||||
if (
|
||||
len(sys.argv) > 1
|
||||
# Explicitly only run migrate for server and worker
|
||||
# `bootstrap_tasks` is a special case as that command might be triggered by the `ak`
|
||||
# script to pre-run certain tasks for an automated install
|
||||
and sys.argv[1] in ["dev_server", "worker", "bootstrap_tasks"]
|
||||
# and don't run if this is the child process of a dev_server
|
||||
and os.environ.get(DJANGO_AUTORELOAD_ENV, None) is None
|
||||
):
|
||||
run_migrations()
|
||||
|
||||
import django
|
||||
|
||||
django.setup()
|
9
authentik/tasks/tasks.py
Normal file
9
authentik/tasks/tasks.py
Normal file
@ -0,0 +1,9 @@
|
||||
from dramatiq import actor
|
||||
|
||||
|
||||
@actor
|
||||
def test_actor():
|
||||
import time
|
||||
|
||||
time.sleep(5)
|
||||
print("done sleeping")
|
Reference in New Issue
Block a user