@ -37,7 +37,7 @@ class VersionSerializer(PassiveSerializer):
|
||||
"""Get latest version from cache"""
|
||||
version_in_cache = cache.get(VERSION_CACHE_KEY)
|
||||
if not version_in_cache: # pragma: no cover
|
||||
update_latest_version.delay()
|
||||
update_latest_version.send()
|
||||
return __version__
|
||||
return version_in_cache
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""test admin tasks"""
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.test import TestCase
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.admin.tasks import (
|
||||
@ -10,7 +11,6 @@ from authentik.admin.tasks import (
|
||||
)
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.config import CONFIG
|
||||
from authentik.tasks.tests import TaskTestCase
|
||||
|
||||
RESPONSE_VALID = {
|
||||
"$schema": "https://version.goauthentik.io/schema.json",
|
||||
@ -23,7 +23,7 @@ RESPONSE_VALID = {
|
||||
}
|
||||
|
||||
|
||||
class TestAdminTasks(TaskTestCase):
|
||||
class TestAdminTasks(TestCase):
|
||||
"""test admin tasks"""
|
||||
|
||||
def test_version_valid_response(self):
|
||||
@ -31,7 +31,6 @@ class TestAdminTasks(TaskTestCase):
|
||||
with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
|
||||
mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
|
||||
update_latest_version.send()
|
||||
self.tasks_join(update_latest_version.queue_name)
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
|
||||
self.assertTrue(
|
||||
Event.objects.filter(
|
||||
@ -42,7 +41,6 @@ class TestAdminTasks(TaskTestCase):
|
||||
)
|
||||
# test that a consecutive check doesn't create a duplicate event
|
||||
update_latest_version.send()
|
||||
self.tasks_join(update_latest_version.queue_name)
|
||||
self.assertEqual(
|
||||
len(
|
||||
Event.objects.filter(
|
||||
@ -59,7 +57,6 @@ class TestAdminTasks(TaskTestCase):
|
||||
with Mocker() as mocker:
|
||||
mocker.get("https://version.goauthentik.io/version.json", status_code=400)
|
||||
update_latest_version.send()
|
||||
self.tasks_join(update_latest_version.queue_name)
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
|
||||
self.assertFalse(
|
||||
Event.objects.filter(
|
||||
@ -71,7 +68,6 @@ class TestAdminTasks(TaskTestCase):
|
||||
"""Test Update checker while its disabled"""
|
||||
with CONFIG.patch("disable_update_check", True):
|
||||
update_latest_version.send()
|
||||
self.tasks_join(update_latest_version.queue_name)
|
||||
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
|
||||
|
||||
def test_clear_update_notifications(self):
|
||||
@ -82,7 +78,6 @@ class TestAdminTasks(TaskTestCase):
|
||||
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
|
||||
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
|
||||
clear_update_notifications.send()
|
||||
self.tasks_join(clear_update_notifications.queue_name)
|
||||
self.assertFalse(
|
||||
Event.objects.filter(
|
||||
action=EventAction.UPDATE_AVAILABLE, context__new_version="1.1"
|
||||
|
@ -1,5 +1,6 @@
|
||||
import dramatiq
|
||||
from dramatiq.encoder import PickleEncoder
|
||||
from django.conf import settings
|
||||
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
|
||||
@ -12,10 +13,11 @@ class AuthentikTasksConfig(ManagedAppConfig):
|
||||
|
||||
def ready(self) -> None:
|
||||
from authentik.tasks.broker import PostgresBroker
|
||||
from authentik.tasks.test import TestBroker
|
||||
from authentik.tasks.middleware import CurrentTask, FullyQualifiedActorName
|
||||
|
||||
dramatiq.set_encoder(PickleEncoder())
|
||||
broker = PostgresBroker()
|
||||
broker = PostgresBroker() if not settings.TEST else TestBroker()
|
||||
broker.add_middleware(FullyQualifiedActorName())
|
||||
# broker.add_middleware(Prometheus())
|
||||
broker.add_middleware(CurrentTask())
|
||||
|
@ -2,9 +2,10 @@ import functools
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from queue import Empty, Queue
|
||||
from queue import Empty, PriorityQueue, Queue
|
||||
from random import randint
|
||||
|
||||
from dramatiq.worker import Worker, has_results_middleware, _ConsumerThread, _WorkerThread
|
||||
import tenacity
|
||||
from django.db import (
|
||||
DEFAULT_DB_ALIAS,
|
||||
@ -19,7 +20,7 @@ from django.db.models import QuerySet
|
||||
from django.utils import timezone
|
||||
from dramatiq.broker import Broker, Consumer, MessageProxy
|
||||
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout, RateLimitExceeded, Retry
|
||||
from dramatiq.message import Message
|
||||
from dramatiq.middleware import (
|
||||
AgeLimit,
|
||||
@ -29,6 +30,7 @@ from dramatiq.middleware import (
|
||||
Prometheus,
|
||||
Retries,
|
||||
ShutdownNotifications,
|
||||
SkipMessage,
|
||||
TimeLimit,
|
||||
default_middleware,
|
||||
)
|
||||
@ -191,7 +193,7 @@ class PostgresBroker(Broker):
|
||||
**query,
|
||||
**defaults,
|
||||
}
|
||||
obj, created = self.query_set.update_or_create(
|
||||
self.query_set.update_or_create(
|
||||
**query,
|
||||
defaults=defaults,
|
||||
create_defaults=create_defaults,
|
||||
@ -317,6 +319,7 @@ class _PostgresConsumer(Consumer):
|
||||
state=TaskState.QUEUED,
|
||||
)
|
||||
# We don't care about locks, requeue occurs on worker stop
|
||||
# TODO: this is not true, we need to handle them
|
||||
|
||||
def _fetch_pending_notifies(self) -> list[Notify]:
|
||||
self.logger.debug(f"Polling for lost messages in {self.queue_name}")
|
||||
|
49
authentik/tasks/test.py
Normal file
49
authentik/tasks/test.py
Normal file
@ -0,0 +1,49 @@
|
||||
from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
|
||||
from dramatiq.broker import Broker, MessageProxy
|
||||
from queue import PriorityQueue
|
||||
|
||||
from authentik.tasks.broker import PostgresBroker
|
||||
|
||||
|
||||
class TestWorker(Worker):
|
||||
def __init__(self, queue_name: str, broker: Broker):
|
||||
super().__init__(broker=broker)
|
||||
self.work_queue = PriorityQueue()
|
||||
self.consumers = {
|
||||
queue_name: _ConsumerThread(
|
||||
broker=self.broker,
|
||||
queue_name=queue_name,
|
||||
prefetch=2,
|
||||
work_queue=None,
|
||||
worker_timeout=1,
|
||||
),
|
||||
}
|
||||
self.consumers[queue_name].consumer = self.broker.consume(
|
||||
queue_name=queue_name,
|
||||
prefetch=2,
|
||||
timeout=1,
|
||||
)
|
||||
self._worker = _WorkerThread(
|
||||
broker=self.broker,
|
||||
consumers=self.consumers,
|
||||
work_queue=None,
|
||||
worker_timeout=1,
|
||||
)
|
||||
|
||||
self.broker.emit_before("worker_boot", self)
|
||||
self.broker.emit_after("worker_boot", self)
|
||||
self.broker.emit_before("worker_thread_boot", self)
|
||||
self.broker.emit_after("worker_thread_boot", self)
|
||||
|
||||
def process_message(self, message: MessageProxy):
|
||||
self.logger.error(f"processing message {message}")
|
||||
self.work_queue.put(message)
|
||||
self._worker.process_message(message)
|
||||
|
||||
|
||||
class TestBroker(PostgresBroker):
|
||||
def enqueue(self, *args, **kwargs):
|
||||
message = super().enqueue(*args, **kwargs)
|
||||
worker = TestWorker(message.queue_name, broker=self)
|
||||
worker.process_message(MessageProxy(message))
|
||||
return message
|
@ -1,26 +0,0 @@
|
||||
from django.test import TestCase
|
||||
from dramatiq import Worker, get_broker
|
||||
|
||||
|
||||
class TaskTestCase(TestCase):
|
||||
def _pre_setup(self):
|
||||
super()._pre_setup()
|
||||
|
||||
self.broker = get_broker()
|
||||
self.broker.flush_all()
|
||||
|
||||
self.worker = Worker(self.broker, worker_timeout=100)
|
||||
self.worker.start()
|
||||
|
||||
def _post_teardown(self):
|
||||
self.worker.stop()
|
||||
|
||||
super()._post_teardown()
|
||||
|
||||
def tasks_join(self, queue_name: str | None = None):
|
||||
if queue_name is None:
|
||||
for queue in self.broker.get_declared_queues():
|
||||
self.broker.join(queue)
|
||||
else:
|
||||
self.broker.join(queue_name)
|
||||
self.worker.join()
|
Reference in New Issue
Block a user