diff --git a/authentik/lib/default.yml b/authentik/lib/default.yml index 751cf55f0b..16381470a0 100644 --- a/authentik/lib/default.yml +++ b/authentik/lib/default.yml @@ -159,6 +159,7 @@ web: worker: processes: 2 threads: 1 + consumer_listen_timeout: 30 storage: media: diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 9f3b42fc77..3fe9fa3e69 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -357,7 +357,7 @@ TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" DRAMATIQ = { "broker_class": "authentik.tasks.broker.Broker", "channel_prefix": "authentik", - "task_class": "authentik.tasks.models.Task", + "task_model": "authentik.tasks.models.Task", "autodiscovery": { "enabled": True, "setup_module": "authentik.tasks.setup", @@ -366,8 +366,12 @@ DRAMATIQ = { "worker": { "processes": CONFIG.get_int("worker.processes", 2), "threads": CONFIG.get_int("worker.threads", 1), + "consumer_listen_timeout": CONFIG.get_int("worker.consumer_listen_timeout", 30), }, + "scheduler_class": "authentik.tasks.schedules.scheduler.Scheduler", + "schedule_model": "authentik.tasks.schedules.models.Schedule", "middlewares": ( + ("django_dramatiq_postgres.middleware.SchedulerMiddleware", {}), ("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}), # TODO: fixme # ("dramatiq.middleware.prometheus.Prometheus", {}), diff --git a/authentik/tasks/schedules/scheduler.py b/authentik/tasks/schedules/scheduler.py index ee135f4789..021a7df69d 100644 --- a/authentik/tasks/schedules/scheduler.py +++ b/authentik/tasks/schedules/scheduler.py @@ -1,3 +1,5 @@ +from time import sleep +from django_dramatiq_postgres.conf import Conf import pglock from django_dramatiq_postgres.scheduler import Scheduler as SchedulerBase from structlog.stdlib import get_logger @@ -20,5 +22,8 @@ class Scheduler(SchedulerBase): with tenant: with self._lock(tenant) as lock_acquired: if not lock_acquired: + self.logger.debug("Could not acquire lock, skipping scheduling") return - self._run() + count = self._run() + self.logger.info(f"Sent {count} scheduled tasks") + sleep(Conf().scheduler_interval) diff --git a/authentik/tasks/setup.py b/authentik/tasks/setup.py index 26b4387ff2..919575726e 100644 --- a/authentik/tasks/setup.py +++ b/authentik/tasks/setup.py @@ -1,4 +1,6 @@ -import authentik.lib.setup # noqa +from authentik.root.setup import setup + +setup() import django # noqa: E402 diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py index de958e83fc..982cb2ce72 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py @@ -83,7 +83,7 @@ class PostgresBroker(Broker): @cached_property def model(self) -> type[TaskBase]: - return import_string(Conf().task_class) + return import_string(Conf().task_model) @property def query_set(self) -> QuerySet: @@ -231,7 +231,7 @@ class _PostgresConsumer(Consumer): # Override because dramatiq doesn't allow us setting this manually # TODO: turn it into a setting - self.timeout = 30000 // 1000 + self.timeout = Conf().worker["consumer_listen_timeout"] @property def connection(self) -> DatabaseWrapper: diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py index 515ce82b86..40d47ae2ee 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/conf.py @@ -10,8 +10,8 @@ class Conf: _ = settings.DRAMATIQ except AttributeError as exc: raise ImproperlyConfigured("Setting DRAMATIQ not set.") from exc - if "task_class" not in self.conf: - raise ImproperlyConfigured("DRAMATIQ.task_class not defined") + if "task_model" not in self.conf: + raise ImproperlyConfigured("DRAMATIQ.task_model not defined") @property def conf(self) -> dict[str, Any]: @@ -53,8 +53,8 @@ class Conf: return self.conf.get("channel_prefix", "dramatiq") @property - def task_class(self) -> str: - return self.conf["task_class"] + def task_model(self) -> str: + return self.conf["task_model"] @property def autodiscovery(self) -> dict[str, Any]: @@ -81,9 +81,22 @@ class Conf: "watch_use_polling": False, "processes": None, "threads": None, + "consumer_listen_timeout": 30, **self.conf.get("worker", {}), } + @property + def scheduler_class(self) -> str: + return self.conf.get("scheduler_class", "django_dramatiq_postgres.scheduler.Scheduler") + + @property + def schedule_model(self) -> str | None: + return self.conf.get("schedule_model") + + @property + def scheduler_interval(self) -> int: + return self.conf.get("scheduler_interval", 60) + @property def test(self) -> bool: return self.conf.get("test", False) diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py index 79175934e0..9ce10525e4 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -1,10 +1,13 @@ import contextvars +from threading import Event from typing import Any +from django.core.exceptions import ImproperlyConfigured from django.db import ( close_old_connections, connections, ) +from django.utils.module_loading import import_string from dramatiq.actor import Actor from dramatiq.broker import Broker from dramatiq.logging import get_logger @@ -13,6 +16,7 @@ from dramatiq.middleware.middleware import Middleware from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.models import TaskBase +from django_dramatiq_postgres.scheduler import Scheduler class DbConnectionMiddleware(Middleware): @@ -74,5 +78,39 @@ class CurrentTask(Middleware): self.logger.warning("Task was None, not saving. This should not happen.") return else: - tasks[-1].save() + task = tasks[-1] + fields_to_exclude = { + "message_id", + "queue_name", + "actor_name", + "message", + "state", + "mtime", + "result", + "result_expiry", + } + fields_to_update = [ + f.name + for f in task._meta.get_fields() + if f.name not in fields_to_exclude and not f.auto_created and f.column + ] + if fields_to_update: + tasks[-1].save(update_fields=fields_to_update) self._TASKS.set(tasks[:-1]) + + +class SchedulerMiddleware(Middleware): + def __init__(self): + self.logger = get_logger(__name__, type(self)) + + if not Conf().schedule_model: + raise ImproperlyConfigured( + "When using the scheduler, DRAMATIQ.schedule_class must be set." + ) + + self.scheduler_stop_event = Event() + self.scheduler: Scheduler = import_string(Conf().scheduler_class)(self.scheduler_stop_event) + + def after_process_boot(self, broker: Broker): + self.scheduler.broker = broker + self.scheduler.start() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py index 56c0db78ad..8c63939aa1 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/scheduler.py @@ -1,3 +1,5 @@ +from threading import Event, Thread +from time import sleep import pglock from django.db import router, transaction from django.db.models import QuerySet @@ -11,14 +13,17 @@ from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.models import ScheduleBase -class Scheduler: - def __init__(self, broker: Broker): +class Scheduler(Thread): + broker: Broker + + def __init__(self, stop_event: Event, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stop_event = stop_event self.logger = get_logger(__name__, type(self)) - self.broker = broker @cached_property def model(self) -> type[ScheduleBase]: - return import_string(Conf().task_class) + return import_string(Conf().schedule_model) @property def query_set(self) -> QuerySet: @@ -36,15 +41,22 @@ class Scheduler: timeout=0, ) - def _run(self): + def _run(self) -> int: + count = 0 with transaction.atomic(using=router.db_for_write(self.model)): for schedule in self.query_set.select_for_update().filter( next_run__lt=now(), ): self.process_schedule(schedule) + count += 1 + return count def run(self): - with self._lock() as lock_acquired: - if not lock_acquired: - return - self._run() + while not self.stop_event.is_set(): + with self._lock() as lock_acquired: + if not lock_acquired: + self.logger.debug("Could not acquire lock, skipping scheduling") + return + count = self._run() + self.logger.info(f"Sent {count} scheduled tasks") + sleep(Conf().scheduler_interval) diff --git a/scripts/generate_config.py b/scripts/generate_config.py index 255622d439..556db23ae5 100755 --- a/scripts/generate_config.py +++ b/scripts/generate_config.py @@ -49,6 +49,7 @@ def generate_local_config(): "worker": { "processes": 1, "threads": 1, + "consumer_listen_timeout": 10, }, }