From 991778b2be42e3b687c26c87e9f3ae310ad6ed5e Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Thu, 27 Mar 2025 20:00:32 +0100 Subject: [PATCH] wip Signed-off-by: Marc 'risson' Schmitt --- authentik/admin/api/version.py | 2 +- authentik/admin/tests/test_tasks.py | 9 ++---- authentik/tasks/apps.py | 4 ++- authentik/tasks/broker.py | 9 ++++-- authentik/tasks/test.py | 49 +++++++++++++++++++++++++++++ authentik/tasks/tests.py | 26 --------------- 6 files changed, 61 insertions(+), 38 deletions(-) create mode 100644 authentik/tasks/test.py delete mode 100644 authentik/tasks/tests.py diff --git a/authentik/admin/api/version.py b/authentik/admin/api/version.py index 72ddfa9eee..4d913294b2 100644 --- a/authentik/admin/api/version.py +++ b/authentik/admin/api/version.py @@ -37,7 +37,7 @@ class VersionSerializer(PassiveSerializer): """Get latest version from cache""" version_in_cache = cache.get(VERSION_CACHE_KEY) if not version_in_cache: # pragma: no cover - update_latest_version.delay() + update_latest_version.send() return __version__ return version_in_cache diff --git a/authentik/admin/tests/test_tasks.py b/authentik/admin/tests/test_tasks.py index 1ca53680cd..126e9fbb8d 100644 --- a/authentik/admin/tests/test_tasks.py +++ b/authentik/admin/tests/test_tasks.py @@ -1,6 +1,7 @@ """test admin tasks""" from django.core.cache import cache +from django.test import TestCase from requests_mock import Mocker from authentik.admin.tasks import ( @@ -10,7 +11,6 @@ from authentik.admin.tasks import ( ) from authentik.events.models import Event, EventAction from authentik.lib.config import CONFIG -from authentik.tasks.tests import TaskTestCase RESPONSE_VALID = { "$schema": "https://version.goauthentik.io/schema.json", @@ -23,7 +23,7 @@ RESPONSE_VALID = { } -class TestAdminTasks(TaskTestCase): +class TestAdminTasks(TestCase): """test admin tasks""" def test_version_valid_response(self): @@ -31,7 +31,6 @@ class TestAdminTasks(TaskTestCase): with Mocker() as mocker, CONFIG.patch("disable_update_check", False): mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID) update_latest_version.send() - self.tasks_join(update_latest_version.queue_name) self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") self.assertTrue( Event.objects.filter( @@ -42,7 +41,6 @@ class TestAdminTasks(TaskTestCase): ) # test that a consecutive check doesn't create a duplicate event update_latest_version.send() - self.tasks_join(update_latest_version.queue_name) self.assertEqual( len( Event.objects.filter( @@ -59,7 +57,6 @@ class TestAdminTasks(TaskTestCase): with Mocker() as mocker: mocker.get("https://version.goauthentik.io/version.json", status_code=400) update_latest_version.send() - self.tasks_join(update_latest_version.queue_name) self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") self.assertFalse( Event.objects.filter( @@ -71,7 +68,6 @@ class TestAdminTasks(TaskTestCase): """Test Update checker while its disabled""" with CONFIG.patch("disable_update_check", True): update_latest_version.send() - self.tasks_join(update_latest_version.queue_name) self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") def test_clear_update_notifications(self): @@ -82,7 +78,6 @@ class TestAdminTasks(TaskTestCase): Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={}) clear_update_notifications.send() - self.tasks_join(clear_update_notifications.queue_name) self.assertFalse( Event.objects.filter( action=EventAction.UPDATE_AVAILABLE, context__new_version="1.1" diff --git a/authentik/tasks/apps.py b/authentik/tasks/apps.py index 96eb430cdd..28849902e5 100644 --- a/authentik/tasks/apps.py +++ b/authentik/tasks/apps.py @@ -1,5 +1,6 @@ import dramatiq from dramatiq.encoder import PickleEncoder +from django.conf import settings from authentik.blueprints.apps import ManagedAppConfig @@ -12,10 +13,11 @@ class AuthentikTasksConfig(ManagedAppConfig): def ready(self) -> None: from authentik.tasks.broker import PostgresBroker + from authentik.tasks.test import TestBroker from authentik.tasks.middleware import CurrentTask, FullyQualifiedActorName dramatiq.set_encoder(PickleEncoder()) - broker = PostgresBroker() + broker = PostgresBroker() if not settings.TEST else TestBroker() broker.add_middleware(FullyQualifiedActorName()) # broker.add_middleware(Prometheus()) broker.add_middleware(CurrentTask()) diff --git a/authentik/tasks/broker.py b/authentik/tasks/broker.py index 5eb4b89de0..bcbad433bc 100644 --- a/authentik/tasks/broker.py +++ b/authentik/tasks/broker.py @@ -2,9 +2,10 @@ import functools import logging import time from collections.abc import Iterable -from queue import Empty, Queue +from queue import Empty, PriorityQueue, Queue from random import randint +from dramatiq.worker import Worker, has_results_middleware, _ConsumerThread, _WorkerThread import tenacity from django.db import ( DEFAULT_DB_ALIAS, @@ -19,7 +20,7 @@ from django.db.models import QuerySet from django.utils import timezone from dramatiq.broker import Broker, Consumer, MessageProxy from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name -from dramatiq.errors import ConnectionError, QueueJoinTimeout +from dramatiq.errors import ConnectionError, QueueJoinTimeout, RateLimitExceeded, Retry from dramatiq.message import Message from dramatiq.middleware import ( AgeLimit, @@ -29,6 +30,7 @@ from dramatiq.middleware import ( Prometheus, Retries, ShutdownNotifications, + SkipMessage, TimeLimit, default_middleware, ) @@ -191,7 +193,7 @@ class PostgresBroker(Broker): **query, **defaults, } - obj, created = self.query_set.update_or_create( + self.query_set.update_or_create( **query, defaults=defaults, create_defaults=create_defaults, @@ -317,6 +319,7 @@ class _PostgresConsumer(Consumer): state=TaskState.QUEUED, ) # We don't care about locks, requeue occurs on worker stop + # TODO: this is not true, we need to handle them def _fetch_pending_notifies(self) -> list[Notify]: self.logger.debug(f"Polling for lost messages in {self.queue_name}") diff --git a/authentik/tasks/test.py b/authentik/tasks/test.py new file mode 100644 index 0000000000..e8100e03ea --- /dev/null +++ b/authentik/tasks/test.py @@ -0,0 +1,49 @@ +from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread +from dramatiq.broker import Broker, MessageProxy +from queue import PriorityQueue + +from authentik.tasks.broker import PostgresBroker + + +class TestWorker(Worker): + def __init__(self, queue_name: str, broker: Broker): + super().__init__(broker=broker) + self.work_queue = PriorityQueue() + self.consumers = { + queue_name: _ConsumerThread( + broker=self.broker, + queue_name=queue_name, + prefetch=2, + work_queue=None, + worker_timeout=1, + ), + } + self.consumers[queue_name].consumer = self.broker.consume( + queue_name=queue_name, + prefetch=2, + timeout=1, + ) + self._worker = _WorkerThread( + broker=self.broker, + consumers=self.consumers, + work_queue=None, + worker_timeout=1, + ) + + self.broker.emit_before("worker_boot", self) + self.broker.emit_after("worker_boot", self) + self.broker.emit_before("worker_thread_boot", self) + self.broker.emit_after("worker_thread_boot", self) + + def process_message(self, message: MessageProxy): + self.logger.error(f"processing message {message}") + self.work_queue.put(message) + self._worker.process_message(message) + + +class TestBroker(PostgresBroker): + def enqueue(self, *args, **kwargs): + message = super().enqueue(*args, **kwargs) + worker = TestWorker(message.queue_name, broker=self) + worker.process_message(MessageProxy(message)) + return message diff --git a/authentik/tasks/tests.py b/authentik/tasks/tests.py deleted file mode 100644 index e202c877e0..0000000000 --- a/authentik/tasks/tests.py +++ /dev/null @@ -1,26 +0,0 @@ -from django.test import TestCase -from dramatiq import Worker, get_broker - - -class TaskTestCase(TestCase): - def _pre_setup(self): - super()._pre_setup() - - self.broker = get_broker() - self.broker.flush_all() - - self.worker = Worker(self.broker, worker_timeout=100) - self.worker.start() - - def _post_teardown(self): - self.worker.stop() - - super()._post_teardown() - - def tasks_join(self, queue_name: str | None = None): - if queue_name is None: - for queue in self.broker.get_declared_queues(): - self.broker.join(queue) - else: - self.broker.join(queue_name) - self.worker.join()