@@ -1,47 +0,0 @@
-"""Run worker"""
-
-from sys import exit as sysexit
-from tempfile import tempdir
-
-from celery.apps.worker import Worker
-from django.core.management.base import BaseCommand
-from django.db import close_old_connections
-from structlog.stdlib import get_logger
-
-from authentik.lib.config import CONFIG
-from authentik.lib.debug import start_debug_server
-from authentik.root.celery import CELERY_APP
-
-LOGGER = get_logger()
-
-
-class Command(BaseCommand):
-    """Run worker"""
-
-    def add_arguments(self, parser):
-        parser.add_argument(
-            "-b",
-            "--beat",
-            action="store_false",
-            help="When set, this worker will _not_ run Beat (scheduled) tasks",
-        )
-
-    def handle(self, **options):
-        LOGGER.debug("Celery options", **options)
-        close_old_connections()
-        start_debug_server()
-        worker: Worker = CELERY_APP.Worker(
-            no_color=False,
-            quiet=True,
-            optimization="fair",
-            autoscale=(CONFIG.get_int("worker.concurrency"), 1),
-            task_events=True,
-            beat=options.get("beat", True),
-            schedule_filename=f"{tempdir}/celerybeat-schedule",
-            queues=["authentik", "authentik_scheduled", "authentik_events"],
-        )
-        for task in CELERY_APP.tasks:
-            LOGGER.debug("Registered task", task=task)
-
-        worker.start()
-        sysexit(worker.exitcode)
@@ -17,7 +17,7 @@ class AuthentikTasksConfig(ManagedAppConfig):
 
         dramatiq.set_encoder(JSONPickleEncoder())
         broker = PostgresBroker()
-        broker.add_middleware(Prometheus())
+        # broker.add_middleware(Prometheus())
         broker.add_middleware(AgeLimit(max_age=timedelta(days=30).total_seconds() * 1000))
         broker.add_middleware(TimeLimit())
         broker.add_middleware(Callbacks())
@@ -1,3 +1,4 @@
+from psycopg import sql
 import functools
 import logging
 import time
@@ -118,7 +119,7 @@ class PostgresBroker(Broker):
         self.declare_queue(canonical_queue_name)
         self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
         self.emit_before("enqueue", message, delay)
-        encoded = message.encode()
+        encoded = message.encode().decode()
         query = {
             "message_id": message.message_id,
         }
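
Context for the `.encode().decode()` change: dramatiq's `Message.encode()` returns `bytes`, and decoding to `str` presumably lets the broker store the payload in the task table's text/JSON message column. A minimal round-trip sketch (queue and actor names are placeholders):

from dramatiq import Message

msg = Message(
    queue_name="authentik",   # placeholder queue name
    actor_name="test_actor",  # placeholder actor name
    args=(),
    kwargs={},
    options={},
)
encoded: bytes = msg.encode()    # bytes from the configured encoder
as_text: str = encoded.decode()  # str, suitable for a TEXT/JSONB column
assert Message.decode(as_text.encode()).message_id == msg.message_id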
@@ -212,7 +213,12 @@ class _PostgresConsumer(Consumer):
         # Should be set to True by Django by default
         self._listen_connection.set_autocommit(True)
         with self._listen_connection.cursor() as cursor:
-            cursor.execute("LISTEN %s", channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
+            cursor.execute(
+                sql.SQL("LISTEN {}").format(
+                    sql.Identifier(channel_name(self.queue_name, ChannelIdentifier.ENQUEUE))
+                )
+            )
+        return self._listen_connection
 
     @raise_connection_error
     def ack(self, message: Message):
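
The replaced `cursor.execute("LISTEN %s", ...)` tried to bind the channel name as a query parameter, which LISTEN does not accept; composing the statement with `sql.Identifier` quotes it as an identifier instead. A standalone sketch, assuming psycopg >= 3.2 and a placeholder DSN and channel name:

import psycopg
from psycopg import sql

with psycopg.connect("dbname=authentik", autocommit=True) as conn:  # placeholder DSN
    with conn.cursor() as cursor:
        # LISTEN takes an identifier, not a literal, so it cannot be bound with %s;
        # sql.Identifier() quotes the channel name safely instead.
        cursor.execute(sql.SQL("LISTEN {}").format(sql.Identifier("authentik.default.enqueue")))
    # Wait up to 5 seconds for NOTIFY events on the channel (timeout= needs psycopg 3.2+)
    for notify in conn.notifies(timeout=5):
        print(notify.channel, notify.payload)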
@@ -223,7 +229,7 @@ class _PostgresConsumer(Consumer):
             state=TaskState.CONSUMED,
         ).update(
             state=TaskState.DONE,
-            message=message.encode(),
+            message=message.encode().decode(),
         )
         self.in_processing.remove(message.message_id)
 
@@ -236,7 +242,7 @@ class _PostgresConsumer(Consumer):
             state__ne=TaskState.REJECTED,
         ).update(
             state=TaskState.REJECTED,
-            message=message.encode(),
+            message=message.encode().decode(),
         )
         self.in_processing.remove(message.message_id)
 
@@ -259,7 +265,7 @@ class _PostgresConsumer(Consumer):
 
     def _poll_for_notify(self):
         with self.listen_connection.cursor() as cursor:
-            notifies = cursor.notifies(timeout=self.timeout)
+            notifies = list(cursor.connection.notifies(timeout=self.timeout))
             self.logger.debug(f"Received {len(notifies)} postgres notifies")
             self.notifies += notifies
 
@@ -278,8 +284,8 @@ class _PostgresConsumer(Consumer):
                 message_id=message.message_id,
                 state__in=(TaskState.QUEUED, TaskState.CONSUMED),
             )
-            .update(state=TaskState.CONSUMED, mtime=timezone.now())
             .extra(where=["pg_try_advisory_lock(%s)"], params=[self._get_message_lock_id(message)])
+            .update(state=TaskState.CONSUMED, mtime=timezone.now())
         )
         return result == 1
 
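
The reordering matters because Django's `.update()` executes immediately and returns the number of affected rows, so the `.extra()` clause carrying the advisory lock has to be chained before it. A hedged sketch against a generic queryset (all names are illustrative):

from django.db.models import QuerySet
from django.utils import timezone

def try_claim(qs: QuerySet, message_id: str, lock_id: int) -> bool:
    # .update() is terminal: it runs the SQL and returns an int, so the extra
    # WHERE clause taking the advisory lock must be applied first.
    rows = (
        qs.filter(message_id=message_id)
        .extra(where=["pg_try_advisory_lock(%s)"], params=[lock_id])
        .update(mtime=timezone.now())
    )
    return rows == 1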
@@ -334,7 +340,7 @@ class _PostgresConsumer(Consumer):
 
         # No message to process
         self._purge_locks()
-        self._auto_pruge()
+        self._auto_purge()
 
     def _purge_locks(self):
         while True:
@@ -344,7 +350,9 @@ class _PostgresConsumer(Consumer):
                 return
             self.logger.debug(f"Unlocking {message.message_id}@{message.queue_name}")
             with self.connection.cursor() as cursor:
-                cursor.execute("SELECT pg_advisory_unlock(%s)", self._get_message_lock_id(message))
+                cursor.execute(
+                    "SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message),)
+                )
             self.unlock_queue.task_done()
 
     def _auto_purge(self):
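
psycopg expects query parameters as a sequence or mapping, so the single lock id has to be wrapped in a one-element tuple; passing it bare raises a TypeError. Sketch with a placeholder DSN and lock id:

import psycopg

lock_id = 12345  # placeholder advisory-lock key
with psycopg.connect("dbname=authentik") as conn:  # placeholder DSN
    with conn.cursor() as cursor:
        # cursor.execute("SELECT pg_advisory_unlock(%s)", lock_id)   # TypeError
        cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))  # correct
        print(cursor.fetchone()[0])  # True only if this session held the lock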
@@ -21,11 +21,11 @@ class JSONPickleEncoder(dramatiq.encoder.Encoder):
             keys=True,
             warn=True,
             use_base85=True,
-        ).encode("utf-8")
+        ).encode()
 
     def decode(self, data: bytes) -> MessageData:
         return jsonpickle.decode(
-            data.decode("utf-8"),
+            data.decode(),
             backend=OrjsonBackend(),
             keys=True,
             on_missing="warn",
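
The dropped "utf-8" arguments are redundant, since `str.encode()` and `bytes.decode()` already default to UTF-8. For context, the encoder wraps jsonpickle so task payloads can carry values plain JSON cannot, such as datetimes and non-string keys; a small round trip under the same flags (without the project's custom orjson backend):

from datetime import datetime, timezone

import jsonpickle

payload = {"when": datetime.now(timezone.utc), ("tenant", 1): "composite key"}
data: bytes = jsonpickle.encode(payload, keys=True, warn=True, use_base85=True).encode()
restored = jsonpickle.decode(data.decode(), keys=True, on_missing="warn")
assert restored["when"] == payload["when"]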
authentik/tasks/management/__init__.py (new file, 0 lines)
authentik/tasks/management/commands/__init__.py (new file, 0 lines)
authentik/tasks/management/commands/worker.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import os
+import sys
+
+from django.core.management.base import BaseCommand
+from django.utils.module_loading import module_has_submodule
+
+from authentik.lib.utils.reflection import get_apps
+
+
+class Command(BaseCommand):
+    """Run worker"""
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            "--reload",
+            action="store_true",
+            dest="use_watcher",
+            help="Enable autoreload",
+        )
+        parser.add_argument(
+            "--reload-use-polling",
+            action="store_true",
+            dest="use_polling_watcher",
+            help="Use a poll-based file watcher for autoreload",
+        )
+        parser.add_argument(
+            "--use-gevent",
+            action="store_true",
+            help="Use gevent for worker concurrency",
+        )
+        parser.add_argument(
+            "--processes",
+            "-p",
+            default=1,
+            type=int,
+            help="The number of processes to run",
+        )
+        parser.add_argument(
+            "--threads",
+            "-t",
+            default=1,
+            type=int,
+            help="The number of threads per process to use",
+        )
+
+    def handle(
+        self, use_watcher, use_polling_watcher, use_gevent, processes, threads, verbosity, **options
+    ):
+        executable_name = "dramatiq-gevent" if use_gevent else "dramatiq"
+        executable_path = self._resolve_executable(executable_name)
+        watch_args = ["--watch", "."] if use_watcher else []
+        if watch_args and use_polling_watcher:
+            watch_args.append("--watch-use-polling")
+
+        verbosity_args = ["-v"] * (verbosity - 1)
+
+        tasks_modules = self._discover_tasks_modules()
+        process_args = [
+            executable_name,
+            "--path",
+            ".",
+            "--processes",
+            str(processes),
+            "--threads",
+            str(threads),
+            *watch_args,
+            *verbosity_args,
+            *tasks_modules,
+        ]
+
+        os.execvp(executable_path, process_args)
+
+    def _resolve_executable(self, exec_name: str):
+        bin_dir = os.path.dirname(sys.executable)
+        if bin_dir:
+            for d in [bin_dir, os.path.join(bin_dir, "Scripts")]:
+                exec_path = os.path.join(d, exec_name)
+                if os.path.isfile(exec_path):
+                    return exec_path
+        return exec_name
+
+    def _discover_tasks_modules(self) -> list[str]:
+        # Does not support a tasks directory
+        return ["authentik.tasks.setup"] + [
+            f"{app.name}.tasks" for app in get_apps() if module_has_submodule(app.module, "tasks")
+        ]
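
For illustration, invoking the command with `--processes 2 --threads 4 --reload` would exec roughly the following argv (module list abridged; everything after `authentik.tasks.setup` is just an example of a discovered `<app>.tasks` module):

process_args = [
    "dramatiq",
    "--path", ".",
    "--processes", "2",
    "--threads", "4",
    "--watch", ".",
    "authentik.tasks.setup",
    "authentik.core.tasks",  # example discovered module
]
# os.execvp() replaces the manage.py process, so SIGTERM/SIGINT reach the
# dramatiq supervisor directly rather than a Python wrapper.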
@@ -62,7 +62,7 @@ class Migration(migrations.Migration):
             },
         ),
         migrations.RunSQL(
-            "ALTER TABLE authentik_tasks_queue SET WITHOUT OIDS;", migrations.RunSQL.noop
+            "ALTER TABLE authentik_tasks_task SET WITHOUT OIDS;", migrations.RunSQL.noop
         ),
         pgtrigger.migrations.AddTrigger(
             model_name="task",
@@ -48,8 +48,8 @@ class Task(SerializerModel):
                 condition=pgtrigger.Q(new__state=TaskState.QUEUED),
                 timing=pgtrigger.Deferred,
                 func=f"""
-                PERFORM pg_notify(
-                    '{CHANNEL_PREFIX}' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
+                SELECT pg_notify(
+                    '{CHANNEL_PREFIX}.' || NEW.queue_name || '.{ChannelIdentifier.ENQUEUE.value}',
                     CASE WHEN octet_length(NEW.message::text) >= 8000
                     THEN jsonb_build_object('message_id', NEW.message_id)::text
                     ELSE NEW.message::text
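
The CASE expression works around NOTIFY's payload limit (8000 bytes by default): oversized messages are announced by id only and have to be re-read from the task table. A consumer-side sketch of handling both payload shapes (`fetch_by_id` is a hypothetical loader):

import json

def resolve_notify_payload(payload: str, fetch_by_id) -> dict:
    data = json.loads(payload)
    if set(data) == {"message_id"}:
        # Payload exceeded the NOTIFY limit; only the id was sent.
        return fetch_by_id(data["message_id"])
    return data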
authentik/tasks/setup.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+import os
+import sys
+import warnings
+
+from authentik.lib.config import CONFIG
+from cryptography.hazmat.backends.openssl.backend import backend
+from defusedxml import defuse_stdlib
+from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
+
+from lifecycle.migrate import run_migrations
+from lifecycle.wait_for_db import wait_for_db
+
+warnings.filterwarnings("ignore", "SelectableGroups dict interface")
+warnings.filterwarnings(
+    "ignore",
+    "defusedxml.lxml is no longer supported and will be removed in a future release.",
+)
+warnings.filterwarnings(
+    "ignore",
+    "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
+)
+
+defuse_stdlib()
+
+if CONFIG.get_bool("compliance.fips.enabled", False):
+    backend._enable_fips()
+
+
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
+wait_for_db()
+print(sys.argv)
+if (
+    len(sys.argv) > 1
+    # Explicitly only run migrate for server and worker
+    # `bootstrap_tasks` is a special case as that command might be triggered by the `ak`
+    # script to pre-run certain tasks for an automated install
+    and sys.argv[1] in ["dev_server", "worker", "bootstrap_tasks"]
+    # and don't run if this is the child process of a dev_server
+    and os.environ.get(DJANGO_AUTORELOAD_ENV, None) is None
+):
+    run_migrations()
+
+import django
+
+django.setup()
authentik/tasks/tasks.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from dramatiq import actor
+
+
+@actor
+def test_actor():
+    import time
+
+    time.sleep(5)
+    print("done sleeping")
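
Usage sketch: with the broker configured in the app config above, the actor is enqueued rather than called inline (the delay option is in milliseconds):

from authentik.tasks.tasks import test_actor

test_actor.send()                           # run on a worker as soon as possible
test_actor.send_with_options(delay=10_000)  # or roughly ten seconds from now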