Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-03-27 20:00:32 +01:00
parent 9465dafd7d
commit 991778b2be
6 changed files with 61 additions and 38 deletions

View File

@ -37,7 +37,7 @@ class VersionSerializer(PassiveSerializer):
"""Get latest version from cache"""
version_in_cache = cache.get(VERSION_CACHE_KEY)
if not version_in_cache: # pragma: no cover
update_latest_version.delay()
update_latest_version.send()
return __version__
return version_in_cache

View File

@ -1,6 +1,7 @@
"""test admin tasks"""
from django.core.cache import cache
from django.test import TestCase
from requests_mock import Mocker
from authentik.admin.tasks import (
@ -10,7 +11,6 @@ from authentik.admin.tasks import (
)
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.tasks.tests import TaskTestCase
RESPONSE_VALID = {
"$schema": "https://version.goauthentik.io/schema.json",
@ -23,7 +23,7 @@ RESPONSE_VALID = {
}
class TestAdminTasks(TaskTestCase):
class TestAdminTasks(TestCase):
"""test admin tasks"""
def test_version_valid_response(self):
@ -31,7 +31,6 @@ class TestAdminTasks(TaskTestCase):
with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
update_latest_version.send()
self.tasks_join(update_latest_version.queue_name)
self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
self.assertTrue(
Event.objects.filter(
@ -42,7 +41,6 @@ class TestAdminTasks(TaskTestCase):
)
# test that a consecutive check doesn't create a duplicate event
update_latest_version.send()
self.tasks_join(update_latest_version.queue_name)
self.assertEqual(
len(
Event.objects.filter(
@ -59,7 +57,6 @@ class TestAdminTasks(TaskTestCase):
with Mocker() as mocker:
mocker.get("https://version.goauthentik.io/version.json", status_code=400)
update_latest_version.send()
self.tasks_join(update_latest_version.queue_name)
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
self.assertFalse(
Event.objects.filter(
@ -71,7 +68,6 @@ class TestAdminTasks(TaskTestCase):
"""Test Update checker while its disabled"""
with CONFIG.patch("disable_update_check", True):
update_latest_version.send()
self.tasks_join(update_latest_version.queue_name)
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
def test_clear_update_notifications(self):
@ -82,7 +78,6 @@ class TestAdminTasks(TaskTestCase):
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
clear_update_notifications.send()
self.tasks_join(clear_update_notifications.queue_name)
self.assertFalse(
Event.objects.filter(
action=EventAction.UPDATE_AVAILABLE, context__new_version="1.1"

View File

@ -1,5 +1,6 @@
import dramatiq
from dramatiq.encoder import PickleEncoder
from django.conf import settings
from authentik.blueprints.apps import ManagedAppConfig
@ -12,10 +13,11 @@ class AuthentikTasksConfig(ManagedAppConfig):
def ready(self) -> None:
from authentik.tasks.broker import PostgresBroker
from authentik.tasks.test import TestBroker
from authentik.tasks.middleware import CurrentTask, FullyQualifiedActorName
dramatiq.set_encoder(PickleEncoder())
broker = PostgresBroker()
broker = PostgresBroker() if not settings.TEST else TestBroker()
broker.add_middleware(FullyQualifiedActorName())
# broker.add_middleware(Prometheus())
broker.add_middleware(CurrentTask())

View File

@ -2,9 +2,10 @@ import functools
import logging
import time
from collections.abc import Iterable
from queue import Empty, Queue
from queue import Empty, PriorityQueue, Queue
from random import randint
from dramatiq.worker import Worker, has_results_middleware, _ConsumerThread, _WorkerThread
import tenacity
from django.db import (
DEFAULT_DB_ALIAS,
@ -19,7 +20,7 @@ from django.db.models import QuerySet
from django.utils import timezone
from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.errors import ConnectionError, QueueJoinTimeout, RateLimitExceeded, Retry
from dramatiq.message import Message
from dramatiq.middleware import (
AgeLimit,
@ -29,6 +30,7 @@ from dramatiq.middleware import (
Prometheus,
Retries,
ShutdownNotifications,
SkipMessage,
TimeLimit,
default_middleware,
)
@ -191,7 +193,7 @@ class PostgresBroker(Broker):
**query,
**defaults,
}
obj, created = self.query_set.update_or_create(
self.query_set.update_or_create(
**query,
defaults=defaults,
create_defaults=create_defaults,
@ -317,6 +319,7 @@ class _PostgresConsumer(Consumer):
state=TaskState.QUEUED,
)
# We don't care about locks, requeue occurs on worker stop
# TODO: this is not true, we need to handle them
def _fetch_pending_notifies(self) -> list[Notify]:
self.logger.debug(f"Polling for lost messages in {self.queue_name}")

49
authentik/tasks/test.py Normal file
View File

@ -0,0 +1,49 @@
from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
from dramatiq.broker import Broker, MessageProxy
from queue import PriorityQueue
from authentik.tasks.broker import PostgresBroker
class TestWorker(Worker):
def __init__(self, queue_name: str, broker: Broker):
super().__init__(broker=broker)
self.work_queue = PriorityQueue()
self.consumers = {
queue_name: _ConsumerThread(
broker=self.broker,
queue_name=queue_name,
prefetch=2,
work_queue=None,
worker_timeout=1,
),
}
self.consumers[queue_name].consumer = self.broker.consume(
queue_name=queue_name,
prefetch=2,
timeout=1,
)
self._worker = _WorkerThread(
broker=self.broker,
consumers=self.consumers,
work_queue=None,
worker_timeout=1,
)
self.broker.emit_before("worker_boot", self)
self.broker.emit_after("worker_boot", self)
self.broker.emit_before("worker_thread_boot", self)
self.broker.emit_after("worker_thread_boot", self)
def process_message(self, message: MessageProxy):
self.logger.error(f"processing message {message}")
self.work_queue.put(message)
self._worker.process_message(message)
class TestBroker(PostgresBroker):
def enqueue(self, *args, **kwargs):
message = super().enqueue(*args, **kwargs)
worker = TestWorker(message.queue_name, broker=self)
worker.process_message(MessageProxy(message))
return message

View File

@ -1,26 +0,0 @@
from django.test import TestCase
from dramatiq import Worker, get_broker
class TaskTestCase(TestCase):
def _pre_setup(self):
super()._pre_setup()
self.broker = get_broker()
self.broker.flush_all()
self.worker = Worker(self.broker, worker_timeout=100)
self.worker.start()
def _post_teardown(self):
self.worker.stop()
super()._post_teardown()
def tasks_join(self, queue_name: str | None = None):
if queue_name is None:
for queue in self.broker.get_declared_queues():
self.broker.join(queue)
else:
self.broker.join(queue_name)
self.worker.join()