diff --git a/authentik/root/settings.py b/authentik/root/settings.py index c6e9cea8fb..cde656534e 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -4,6 +4,7 @@ import importlib from collections import OrderedDict from hashlib import sha512 from pathlib import Path +from tempfile import gettempdir import orjson from sentry_sdk import set_tag @@ -411,7 +412,13 @@ DRAMATIQ = { ("authentik.tasks.middleware.LoggingMiddleware", {}), ("authentik.tasks.middleware.DescriptionMiddleware", {}), ("authentik.tasks.middleware.WorkerStatusMiddleware", {}), - ("authentik.tasks.middleware.MetricsMiddleware", {}), + ( + "authentik.tasks.middleware.MetricsMiddleware", + { + "multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"), + "prefix": "authentik", + }, + ), ), "test": TEST, } diff --git a/authentik/tasks/forks.py b/authentik/tasks/forks.py index 40591a755d..457f3852dc 100644 --- a/authentik/tasks/forks.py +++ b/authentik/tasks/forks.py @@ -1,3 +1,12 @@ +from signal import pause + +from structlog.stdlib import get_logger + +from authentik.lib.config import CONFIG + +LOGGER = get_logger() + + def worker_status(): import authentik.tasks.setup # noqa from authentik.tasks.middleware import WorkerStatusMiddleware @@ -9,4 +18,11 @@ def worker_metrics(): import authentik.tasks.setup # noqa from authentik.tasks.middleware import MetricsMiddleware - MetricsMiddleware.run() + addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":") + + try: + port = int(port) + MetricsMiddleware.run(addr, port) + except ValueError: + LOGGER.error(f"Invalid port entered: {port}") + pause() diff --git a/authentik/tasks/middleware.py b/authentik/tasks/middleware.py index 251f3dee75..8f129385e5 100644 --- a/authentik/tasks/middleware.py +++ b/authentik/tasks/middleware.py @@ -14,6 +14,7 @@ from dramatiq.common import current_millis from dramatiq.message import Message from dramatiq.middleware import Middleware from structlog.stdlib import get_logger +from django_dramatiq_postgres.middleware import MetricsMiddleware as BaseMetricsMiddleware from authentik import get_full_version from authentik.events.models import Event, EventAction @@ -169,153 +170,9 @@ class WorkerStatusMiddleware(Middleware): sleep(30) -class MetricsMiddleware(Middleware): - def __init__(self): - super().__init__() - self.delayed_messages = set() - self.message_start_times = {} - - _tmp = Path(gettempdir()) - prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp")) - os.makedirs(prometheus_tmp_dir, exist_ok=True) - os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir) - +class MetricsMiddleware(BaseMetricsMiddleware): @property def forks(self): from authentik.tasks.forks import worker_metrics return [worker_metrics] - - def before_worker_boot(self, broker: Broker, worker): - if settings.TEST: - return - - from prometheus_client import Counter, Gauge, Histogram - - self.total_messages = Counter( - "authentik_tasks_total", - "The total number of tasks processed.", - ["queue_name", "actor_name"], - ) - self.total_errored_messages = Counter( - "authentik_tasks_errors_total", - "The total number of errored tasks.", - ["queue_name", "actor_name"], - ) - self.total_retried_messages = Counter( - "authentik_tasks_retries_total", - "The total number of retried tasks.", - ["queue_name", "actor_name"], - ) - self.total_rejected_messages = Counter( - "authentik_tasks_rejected_total", - "The total number of dead-lettered tasks.", - ["queue_name", "actor_name"], - ) - self.inprogress_messages = Gauge( - "authentik_tasks_inprogress", - "The number of tasks in progress.", - ["queue_name", "actor_name"], - multiprocess_mode="livesum", - ) - self.inprogress_delayed_messages = Gauge( - "authentik_tasks_delayed_inprogress", - "The number of delayed tasks in memory.", - ["queue_name", "actor_name"], - ) - self.messages_durations = Histogram( - "authentik_tasks_duration_miliseconds", - "The time spent processing tasks.", - ["queue_name", "actor_name"], - buckets=( - 5, - 10, - 25, - 50, - 75, - 100, - 250, - 500, - 750, - 1_000, - 2_500, - 5_000, - 7_500, - 10_000, - 30_000, - 60_000, - 600_000, - 900_000, - 1_800_000, - 3_600_000, - float("inf"), - ), - ) - - def after_worker_shutdown(self, broker: Broker, worker): - from prometheus_client import multiprocess - - # TODO: worker_id - multiprocess.mark_process_dead(os.getpid()) - - def _make_labels(self, message: Message) -> tuple[str, str]: - return (message.queue_name, message.actor_name) - - def after_nack(self, broker: Broker, message: Message): - self.total_rejected_messages.labels(*self._make_labels(message)).inc() - - def after_enqueue(self, broker: Broker, message: Message, delay: int): - if "retries" in message.options: - self.total_retried_messages.labels(*self._make_labels(message)).inc() - - def before_delay_message(self, broker: Broker, message: Message): - self.delayed_messages.add(message.message_id) - self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc() - - def before_process_message(self, broker: Broker, message: Message): - labels = self._make_labels(message) - if message.message_id in self.delayed_messages: - self.delayed_messages.remove(message.message_id) - self.inprogress_delayed_messages.labels(*labels).dec() - - self.inprogress_messages.labels(*labels).inc() - self.message_start_times[message.message_id] = current_millis() - - def after_process_message( - self, - broker: Broker, - message: Message, - *, - result: Any | None = None, - exception: Exception | None = None, - ): - labels = self._make_labels(message) - - message_start_time = self.message_start_times.pop(message.message_id, current_millis()) - message_duration = current_millis() - message_start_time - self.messages_durations.labels(*labels).observe(message_duration) - - self.inprogress_messages.labels(*labels).dec() - self.total_messages.labels(*labels).inc() - if exception is not None: - self.total_errored_messages.labels(*labels).inc() - - after_skip_message = after_process_message - - @staticmethod - def run(): - from prometheus_client import CollectorRegistry, multiprocess, start_http_server - - addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":") - - try: - port = int(port) - - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) - start_http_server(port, addr, registry) - except ValueError: - LOGGER.error(f"Invalid port entered: {port}") - except OSError: - LOGGER.warning("Port is already in use, not starting metrics server") - pause() diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py index 3f762ce780..ef23b74f81 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -1,4 +1,6 @@ import contextvars +import os +from signal import pause from typing import Any from django.db import ( @@ -7,6 +9,7 @@ from django.db import ( ) from dramatiq.actor import Actor from dramatiq.broker import Broker +from dramatiq.common import current_millis from dramatiq.logging import get_logger from dramatiq.message import Message from dramatiq.middleware.middleware import Middleware @@ -101,3 +104,153 @@ class CurrentTask(Middleware): def after_skip_message(self, broker: Broker, message: Message): self.after_process_message(broker, message) + + +class MetricsMiddleware(Middleware): + def __init__( + self, + prefix: str, + multiproc_dir: str, + labels: list[str] | None = None, + ): + super().__init__() + self.prefix = prefix + self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"] + + self.delayed_messages = set() + self.message_start_times = {} + + os.makedirs(multiproc_dir, exist_ok=True) + os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", multiproc_dir) + + def before_worker_boot(self, broker: Broker, worker): + if Conf().test: + return + + from prometheus_client import Counter, Gauge, Histogram + + self.total_messages = Counter( + f"{self.prefix}_tasks_total", + "The total number of tasks processed.", + self.labels, + ) + self.total_errored_messages = Counter( + f"{self.prefix}_tasks_errors_total", + "The total number of errored tasks.", + self.labels, + ) + self.total_retried_messages = Counter( + f"{self.prefix}_tasks_retries_total", + "The total number of retried tasks.", + self.labels, + ) + self.total_rejected_messages = Counter( + f"{self.prefix}_tasks_rejected_total", + "The total number of dead-lettered tasks.", + self.labels, + ) + self.inprogress_messages = Gauge( + f"{self.prefix}_tasks_inprogress", + "The number of tasks in progress.", + self.labels, + multiprocess_mode="livesum", + ) + self.inprogress_delayed_messages = Gauge( + f"{self.prefix}_tasks_delayed_inprogress", + "The number of delayed tasks in memory.", + self.labels, + ) + self.messages_durations = Histogram( + f"{self.prefix}_tasks_duration_miliseconds", + "The time spent processing tasks.", + self.labels, + buckets=( + 5, + 10, + 25, + 50, + 75, + 100, + 250, + 500, + 750, + 1_000, + 2_500, + 5_000, + 7_500, + 10_000, + 30_000, + 60_000, + 600_000, + 900_000, + 1_800_000, + 3_600_000, + float("inf"), + ), + ) + + def after_worker_shutdown(self, broker: Broker, worker): + from prometheus_client import multiprocess + + # TODO: worker_id + multiprocess.mark_process_dead(os.getpid()) + + def _make_labels(self, message: Message) -> list[str]: + return [message.queue_name, message.actor_name] + + def after_nack(self, broker: Broker, message: Message): + self.total_rejected_messages.labels(*self._make_labels(message)).inc() + + def after_enqueue(self, broker: Broker, message: Message, delay: int): + if "retries" in message.options: + self.total_retried_messages.labels(*self._make_labels(message)).inc() + + def before_delay_message(self, broker: Broker, message: Message): + self.delayed_messages.add(message.message_id) + self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc() + + def before_process_message(self, broker: Broker, message: Message): + labels = self._make_labels(message) + if message.message_id in self.delayed_messages: + self.delayed_messages.remove(message.message_id) + self.inprogress_delayed_messages.labels(*labels).dec() + + self.inprogress_messages.labels(*labels).inc() + self.message_start_times[message.message_id] = current_millis() + + def after_process_message( + self, + broker: Broker, + message: Message, + *, + result: Any | None = None, + exception: Exception | None = None, + ): + labels = self._make_labels(message) + + message_start_time = self.message_start_times.pop(message.message_id, current_millis()) + message_duration = current_millis() - message_start_time + self.messages_durations.labels(*labels).observe(message_duration) + + self.inprogress_messages.labels(*labels).dec() + self.total_messages.labels(*labels).inc() + if exception is not None: + self.total_errored_messages.labels(*labels).inc() + + after_skip_message = after_process_message + + @staticmethod + def run(addr: str, port: int): + from prometheus_client import CollectorRegistry, multiprocess, start_http_server + + try: + port = int(port) + + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + start_http_server(port, addr, registry) + except OSError: + get_logger(__name__, type(MetricsMiddleware)).warning( + "Port is already in use, not starting metrics server" + ) + pause()