move metricsmiddleware to package

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-30 09:38:40 +02:00
parent 6b3fbb0abf
commit c94fa13826
4 changed files with 180 additions and 147 deletions

View File

@ -4,6 +4,7 @@ import importlib
from collections import OrderedDict
from hashlib import sha512
from pathlib import Path
from tempfile import gettempdir
import orjson
from sentry_sdk import set_tag
@ -411,7 +412,13 @@ DRAMATIQ = {
("authentik.tasks.middleware.LoggingMiddleware", {}),
("authentik.tasks.middleware.DescriptionMiddleware", {}),
("authentik.tasks.middleware.WorkerStatusMiddleware", {}),
("authentik.tasks.middleware.MetricsMiddleware", {}),
(
"authentik.tasks.middleware.MetricsMiddleware",
{
"multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"),
"prefix": "authentik",
},
),
),
"test": TEST,
}

View File

@ -1,3 +1,12 @@
from signal import pause
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
LOGGER = get_logger()
def worker_status():
import authentik.tasks.setup # noqa
from authentik.tasks.middleware import WorkerStatusMiddleware
@ -9,4 +18,11 @@ def worker_metrics():
import authentik.tasks.setup # noqa
from authentik.tasks.middleware import MetricsMiddleware
MetricsMiddleware.run()
addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":")
try:
port = int(port)
MetricsMiddleware.run(addr, port)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
pause()

View File

@ -14,6 +14,7 @@ from dramatiq.common import current_millis
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger
from django_dramatiq_postgres.middleware import MetricsMiddleware as BaseMetricsMiddleware
from authentik import get_full_version
from authentik.events.models import Event, EventAction
@ -169,153 +170,9 @@ class WorkerStatusMiddleware(Middleware):
sleep(30)
# NOTE(review): this span looks like a rendered diff hunk, so two versions are
# interleaved: the in-tree implementation that starts here and the short
# replacement subclass a few lines below. Indentation was lost in rendering.
class MetricsMiddleware(Middleware):
# Sets up the per-process bookkeeping containers and the Prometheus
# multiprocess directory under the system temp dir.
def __init__(self):
super().__init__()
self.delayed_messages = set()
self.message_start_times = {}
_tmp = Path(gettempdir())
prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp"))
os.makedirs(prometheus_tmp_dir, exist_ok=True)
# setdefault: an already exported PROMETHEUS_MULTIPROC_DIR is left untouched.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir)
# Replacement: subclass of the packaged middleware; only `forks` is defined here.
class MetricsMiddleware(BaseMetricsMiddleware):
@property
def forks(self):
# Imported inside the property rather than at module level — presumably to
# avoid a circular import with authentik.tasks.forks; confirm.
from authentik.tasks.forks import worker_metrics
return [worker_metrics]
# Creates the metric objects when the worker boots; skipped when settings.TEST is set.
def before_worker_boot(self, broker: Broker, worker):
if settings.TEST:
return
from prometheus_client import Counter, Gauge, Histogram
self.total_messages = Counter(
"authentik_tasks_total",
"The total number of tasks processed.",
["queue_name", "actor_name"],
)
self.total_errored_messages = Counter(
"authentik_tasks_errors_total",
"The total number of errored tasks.",
["queue_name", "actor_name"],
)
self.total_retried_messages = Counter(
"authentik_tasks_retries_total",
"The total number of retried tasks.",
["queue_name", "actor_name"],
)
self.total_rejected_messages = Counter(
"authentik_tasks_rejected_total",
"The total number of dead-lettered tasks.",
["queue_name", "actor_name"],
)
self.inprogress_messages = Gauge(
"authentik_tasks_inprogress",
"The number of tasks in progress.",
["queue_name", "actor_name"],
multiprocess_mode="livesum",
)
self.inprogress_delayed_messages = Gauge(
"authentik_tasks_delayed_inprogress",
"The number of delayed tasks in memory.",
["queue_name", "actor_name"],
)
# Durations are observed in milliseconds (see current_millis() below); the
# "miliseconds" spelling is part of the exported metric name.
self.messages_durations = Histogram(
"authentik_tasks_duration_miliseconds",
"The time spent processing tasks.",
["queue_name", "actor_name"],
buckets=(
5,
10,
25,
50,
75,
100,
250,
500,
750,
1_000,
2_500,
5_000,
7_500,
10_000,
30_000,
60_000,
600_000,
900_000,
1_800_000,
3_600_000,
float("inf"),
),
)
# Tells prometheus_client that this process (identified by its PID) has exited.
def after_worker_shutdown(self, broker: Broker, worker):
from prometheus_client import multiprocess
# TODO: worker_id
multiprocess.mark_process_dead(os.getpid())
# Label values, in the same order as the label names declared on every metric.
def _make_labels(self, message: Message) -> tuple[str, str]:
return (message.queue_name, message.actor_name)
# Counts a rejected (nacked) message.
def after_nack(self, broker: Broker, message: Message):
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
# Counts an enqueue as a retry only when the message carries a "retries" option.
def after_enqueue(self, broker: Broker, message: Message, delay: int):
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()
# Remembers the delayed message id so before_process_message can undo the gauge increment.
def before_delay_message(self, broker: Broker, message: Message):
self.delayed_messages.add(message.message_id)
self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc()
# Moves a message from "delayed" to "in progress" and records its start time.
def before_process_message(self, broker: Broker, message: Message):
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.inprogress_delayed_messages.labels(*labels).dec()
self.inprogress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis()
# Records duration and totals for a finished message; also counts an error
# when an exception is passed in.
def after_process_message(
self,
broker: Broker,
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
):
labels = self._make_labels(message)
# Falls back to "now" when no start time was recorded, giving a duration of ~0.
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
message_duration = current_millis() - message_start_time
self.messages_durations.labels(*labels).observe(message_duration)
self.inprogress_messages.labels(*labels).dec()
self.total_messages.labels(*labels).inc()
if exception is not None:
self.total_errored_messages.labels(*labels).inc()
# Skipped messages are accounted for exactly like processed ones.
after_skip_message = after_process_message
# Starts the HTTP metrics endpoint on the address from "listen.listen_metrics".
@staticmethod
def run():
from prometheus_client import CollectorRegistry, multiprocess, start_http_server
# rpartition splits on the last ":" so everything before it is the address.
addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":")
try:
port = int(port)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
start_http_server(port, addr, registry)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
except OSError:
LOGGER.warning("Port is already in use, not starting metrics server")
# NOTE(review): with indentation lost it is unclear whether pause() belongs to
# the except branch or to the function body — confirm against the real file.
pause()

View File

@ -1,4 +1,6 @@
import contextvars
import os
from signal import pause
from typing import Any
from django.db import (
@ -7,6 +9,7 @@ from django.db import (
)
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.common import current_millis
from dramatiq.logging import get_logger
from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware
@ -101,3 +104,153 @@ class CurrentTask(Middleware):
def after_skip_message(self, broker: Broker, message: Message):
self.after_process_message(broker, message)
class MetricsMiddleware(Middleware):
    """Dramatiq middleware exporting per-task Prometheus metrics.

    Counters, gauges and a duration histogram are created when a worker boots
    and are updated from the message lifecycle hooks. ``run`` serves the values
    collected through ``prometheus_client``'s multiprocess collector over HTTP.

    Every metric is declared with ``self.labels`` while ``_make_labels`` yields
    exactly two values (queue name, actor name). A caller passing custom
    ``labels`` is presumably expected to override ``_make_labels`` so that the
    number of values matches -- confirm against the intended extension point.
    """

    def __init__(
        self,
        prefix: str,
        multiproc_dir: str,
        labels: list[str] | None = None,
    ):
        """Store the configuration and prepare the multiprocess directory.

        Args:
            prefix: String prepended to every metric name (``<prefix>_tasks_...``).
            multiproc_dir: Directory created if missing and exported as
                ``PROMETHEUS_MULTIPROC_DIR`` unless that variable is already set.
            labels: Label names used when declaring the metrics; defaults to
                ``["queue_name", "actor_name"]``.
        """
        super().__init__()
        self.prefix = prefix
        self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]
        # Ids of messages currently counted in the "delayed" gauge.
        self.delayed_messages = set()
        # message_id -> start timestamp in milliseconds.
        self.message_start_times = {}

        os.makedirs(multiproc_dir, exist_ok=True)
        # setdefault: a value already present in the environment wins over
        # ``multiproc_dir``.
        os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", multiproc_dir)

    def before_worker_boot(self, broker: Broker, worker):
        """Create the Prometheus metric objects for this worker process.

        Nothing is created when ``Conf().test`` is truthy.
        """
        # NOTE(review): in test mode the metric attributes stay undefined, so the
        # message hooks below would raise AttributeError if they were invoked --
        # presumably they are not exercised in that mode; confirm.
        if Conf().test:
            return

        # Imported here rather than at module level -- presumably so that
        # PROMETHEUS_MULTIPROC_DIR (set in __init__) is in place first; confirm.
        from prometheus_client import Counter, Gauge, Histogram

        self.total_messages = Counter(
            f"{self.prefix}_tasks_total",
            "The total number of tasks processed.",
            self.labels,
        )
        self.total_errored_messages = Counter(
            f"{self.prefix}_tasks_errors_total",
            "The total number of errored tasks.",
            self.labels,
        )
        self.total_retried_messages = Counter(
            f"{self.prefix}_tasks_retries_total",
            "The total number of retried tasks.",
            self.labels,
        )
        self.total_rejected_messages = Counter(
            f"{self.prefix}_tasks_rejected_total",
            "The total number of dead-lettered tasks.",
            self.labels,
        )
        self.inprogress_messages = Gauge(
            f"{self.prefix}_tasks_inprogress",
            "The number of tasks in progress.",
            self.labels,
            multiprocess_mode="livesum",
        )
        self.inprogress_delayed_messages = Gauge(
            f"{self.prefix}_tasks_delayed_inprogress",
            "The number of delayed tasks in memory.",
            self.labels,
        )
        # Observed values are millisecond durations. The "miliseconds" spelling
        # is part of the exported metric name and is kept as-is so existing
        # dashboards and alerts keep matching.
        self.messages_durations = Histogram(
            f"{self.prefix}_tasks_duration_miliseconds",
            "The time spent processing tasks.",
            self.labels,
            buckets=(
                5,
                10,
                25,
                50,
                75,
                100,
                250,
                500,
                750,
                1_000,
                2_500,
                5_000,
                7_500,
                10_000,
                30_000,
                60_000,
                600_000,
                900_000,
                1_800_000,
                3_600_000,
                float("inf"),
            ),
        )

    def after_worker_shutdown(self, broker: Broker, worker):
        """Tell ``prometheus_client`` that this process has exited."""
        from prometheus_client import multiprocess

        # TODO: worker_id
        multiprocess.mark_process_dead(os.getpid())

    def _make_labels(self, message: Message) -> list[str]:
        """Return the label values for ``message`` as [queue name, actor name]."""
        return [message.queue_name, message.actor_name]

    def after_nack(self, broker: Broker, message: Message):
        """Count a rejected (nacked) message."""
        self.total_rejected_messages.labels(*self._make_labels(message)).inc()

    def after_enqueue(self, broker: Broker, message: Message, delay: int):
        """Count the enqueue as a retry when the message has a ``retries`` option."""
        if "retries" in message.options:
            self.total_retried_messages.labels(*self._make_labels(message)).inc()

    def before_delay_message(self, broker: Broker, message: Message):
        """Track a delayed message and bump the delayed gauge."""
        self.delayed_messages.add(message.message_id)
        self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc()

    def before_process_message(self, broker: Broker, message: Message):
        """Mark ``message`` as in progress and record its start time."""
        labels = self._make_labels(message)
        # A message seen by before_delay_message leaves the delayed gauge here.
        if message.message_id in self.delayed_messages:
            self.delayed_messages.remove(message.message_id)
            self.inprogress_delayed_messages.labels(*labels).dec()

        self.inprogress_messages.labels(*labels).inc()
        self.message_start_times[message.message_id] = current_millis()

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Record duration and totals for a finished message.

        Args:
            broker: Broker the message was processed on (unused here).
            message: The processed message.
            result: Result of the task (unused here).
            exception: Exception raised by the task, if any; a non-``None``
                value also increments the error counter.
        """
        labels = self._make_labels(message)

        # Without a recorded start time the fallback is "now", i.e. a ~0 duration.
        message_start_time = self.message_start_times.pop(message.message_id, current_millis())
        message_duration = current_millis() - message_start_time
        self.messages_durations.labels(*labels).observe(message_duration)

        self.inprogress_messages.labels(*labels).dec()
        self.total_messages.labels(*labels).inc()
        if exception is not None:
            self.total_errored_messages.labels(*labels).inc()

    # Skipped messages are accounted for exactly like processed ones.
    after_skip_message = after_process_message

    @staticmethod
    def run(addr: str, port: int):
        """Serve the multiprocess metrics over HTTP, then block on ``pause()``.

        Args:
            addr: Address to bind the HTTP server to.
            port: Port to bind; coerced with ``int()``.

        Raises:
            ValueError: If ``port`` cannot be converted to an integer. This is
                not handled here and propagates to the caller.
        """
        from prometheus_client import CollectorRegistry, multiprocess, start_http_server

        try:
            port = int(port)
            registry = CollectorRegistry()
            multiprocess.MultiProcessCollector(registry)
            start_http_server(port, addr, registry)
        except OSError:
            # Pass the class itself: ``type(MetricsMiddleware)`` is the
            # metaclass ``type``, which would name the logger "<module>.type".
            get_logger(__name__, MetricsMiddleware).warning(
                "Port is already in use, not starting metrics server"
            )
        # Block until a signal arrives, whether or not the server started.
        # NOTE(review): placed at function level on the understanding that
        # start_http_server returns once its serving thread is running, so the
        # process must be kept alive here -- confirm against the original file.
        pause()