move metricsmiddleware to package

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-30 09:38:40 +02:00
parent 6b3fbb0abf
commit c94fa13826
4 changed files with 180 additions and 147 deletions

View File

@ -4,6 +4,7 @@ import importlib
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha512 from hashlib import sha512
from pathlib import Path from pathlib import Path
from tempfile import gettempdir
import orjson import orjson
from sentry_sdk import set_tag from sentry_sdk import set_tag
@ -411,7 +412,13 @@ DRAMATIQ = {
("authentik.tasks.middleware.LoggingMiddleware", {}), ("authentik.tasks.middleware.LoggingMiddleware", {}),
("authentik.tasks.middleware.DescriptionMiddleware", {}), ("authentik.tasks.middleware.DescriptionMiddleware", {}),
("authentik.tasks.middleware.WorkerStatusMiddleware", {}), ("authentik.tasks.middleware.WorkerStatusMiddleware", {}),
("authentik.tasks.middleware.MetricsMiddleware", {}), (
"authentik.tasks.middleware.MetricsMiddleware",
{
"multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"),
"prefix": "authentik",
},
),
), ),
"test": TEST, "test": TEST,
} }

View File

@ -1,3 +1,12 @@
from signal import pause
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
LOGGER = get_logger()
def worker_status(): def worker_status():
import authentik.tasks.setup # noqa import authentik.tasks.setup # noqa
from authentik.tasks.middleware import WorkerStatusMiddleware from authentik.tasks.middleware import WorkerStatusMiddleware
@ -9,4 +18,11 @@ def worker_metrics():
import authentik.tasks.setup # noqa import authentik.tasks.setup # noqa
from authentik.tasks.middleware import MetricsMiddleware from authentik.tasks.middleware import MetricsMiddleware
MetricsMiddleware.run() addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":")
try:
port = int(port)
MetricsMiddleware.run(addr, port)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
pause()

View File

@ -14,6 +14,7 @@ from dramatiq.common import current_millis
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import Middleware from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from django_dramatiq_postgres.middleware import MetricsMiddleware as BaseMetricsMiddleware
from authentik import get_full_version from authentik import get_full_version
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
@ -169,153 +170,9 @@ class WorkerStatusMiddleware(Middleware):
sleep(30) sleep(30)
class MetricsMiddleware(Middleware): class MetricsMiddleware(BaseMetricsMiddleware):
def __init__(self):
super().__init__()
self.delayed_messages = set()
self.message_start_times = {}
_tmp = Path(gettempdir())
prometheus_tmp_dir = str(_tmp.joinpath("authentik_prometheus_tmp"))
os.makedirs(prometheus_tmp_dir, exist_ok=True)
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", prometheus_tmp_dir)
@property @property
def forks(self): def forks(self):
from authentik.tasks.forks import worker_metrics from authentik.tasks.forks import worker_metrics
return [worker_metrics] return [worker_metrics]
def before_worker_boot(self, broker: Broker, worker):
if settings.TEST:
return
from prometheus_client import Counter, Gauge, Histogram
self.total_messages = Counter(
"authentik_tasks_total",
"The total number of tasks processed.",
["queue_name", "actor_name"],
)
self.total_errored_messages = Counter(
"authentik_tasks_errors_total",
"The total number of errored tasks.",
["queue_name", "actor_name"],
)
self.total_retried_messages = Counter(
"authentik_tasks_retries_total",
"The total number of retried tasks.",
["queue_name", "actor_name"],
)
self.total_rejected_messages = Counter(
"authentik_tasks_rejected_total",
"The total number of dead-lettered tasks.",
["queue_name", "actor_name"],
)
self.inprogress_messages = Gauge(
"authentik_tasks_inprogress",
"The number of tasks in progress.",
["queue_name", "actor_name"],
multiprocess_mode="livesum",
)
self.inprogress_delayed_messages = Gauge(
"authentik_tasks_delayed_inprogress",
"The number of delayed tasks in memory.",
["queue_name", "actor_name"],
)
self.messages_durations = Histogram(
"authentik_tasks_duration_miliseconds",
"The time spent processing tasks.",
["queue_name", "actor_name"],
buckets=(
5,
10,
25,
50,
75,
100,
250,
500,
750,
1_000,
2_500,
5_000,
7_500,
10_000,
30_000,
60_000,
600_000,
900_000,
1_800_000,
3_600_000,
float("inf"),
),
)
def after_worker_shutdown(self, broker: Broker, worker):
from prometheus_client import multiprocess
# TODO: worker_id
multiprocess.mark_process_dead(os.getpid())
def _make_labels(self, message: Message) -> tuple[str, str]:
return (message.queue_name, message.actor_name)
def after_nack(self, broker: Broker, message: Message):
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
def after_enqueue(self, broker: Broker, message: Message, delay: int):
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()
def before_delay_message(self, broker: Broker, message: Message):
self.delayed_messages.add(message.message_id)
self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc()
def before_process_message(self, broker: Broker, message: Message):
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.inprogress_delayed_messages.labels(*labels).dec()
self.inprogress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis()
def after_process_message(
self,
broker: Broker,
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
):
labels = self._make_labels(message)
message_start_time = self.message_start_times.pop(message.message_id, current_millis())
message_duration = current_millis() - message_start_time
self.messages_durations.labels(*labels).observe(message_duration)
self.inprogress_messages.labels(*labels).dec()
self.total_messages.labels(*labels).inc()
if exception is not None:
self.total_errored_messages.labels(*labels).inc()
after_skip_message = after_process_message
@staticmethod
def run():
from prometheus_client import CollectorRegistry, multiprocess, start_http_server
addr, _, port = CONFIG.get("listen.listen_metrics").rpartition(":")
try:
port = int(port)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
start_http_server(port, addr, registry)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
except OSError:
LOGGER.warning("Port is already in use, not starting metrics server")
pause()

View File

@ -1,4 +1,6 @@
import contextvars import contextvars
import os
from signal import pause
from typing import Any from typing import Any
from django.db import ( from django.db import (
@ -7,6 +9,7 @@ from django.db import (
) )
from dramatiq.actor import Actor from dramatiq.actor import Actor
from dramatiq.broker import Broker from dramatiq.broker import Broker
from dramatiq.common import current_millis
from dramatiq.logging import get_logger from dramatiq.logging import get_logger
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware from dramatiq.middleware.middleware import Middleware
@ -101,3 +104,153 @@ class CurrentTask(Middleware):
def after_skip_message(self, broker: Broker, message: Message): def after_skip_message(self, broker: Broker, message: Message):
self.after_process_message(broker, message) self.after_process_message(broker, message)
class MetricsMiddleware(Middleware):
    """Dramatiq middleware that records task metrics with prometheus_client.

    Metric objects are created in ``before_worker_boot`` and updated from the
    message lifecycle hooks below. ``run`` starts an HTTP server exposing a
    registry fed by ``prometheus_client.multiprocess.MultiProcessCollector``.

    Args:
        prefix: String prepended to every metric name (``<prefix>_tasks_...``).
        multiproc_dir: Directory created on construction and used as the
            default value for the ``PROMETHEUS_MULTIPROC_DIR`` environment
            variable.
        labels: Label names declared on every metric. Defaults to
            ``["queue_name", "actor_name"]``. ``_make_labels`` always returns
            exactly those two values, so a subclass that passes different
            label names presumably needs to override ``_make_labels`` too.
    """

    def __init__(
        self,
        prefix: str,
        multiproc_dir: str,
        labels: list[str] | None = None,
    ):
        super().__init__()
        self.prefix = prefix
        self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]
        # IDs of messages announced via before_delay_message and not yet processed.
        self.delayed_messages = set()
        # message_id -> start timestamp (ms), filled in before_process_message.
        self.message_start_times = {}

        os.makedirs(multiproc_dir, exist_ok=True)
        # setdefault: an already exported PROMETHEUS_MULTIPROC_DIR wins over
        # the directory passed in here.
        os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", multiproc_dir)

    def before_worker_boot(self, broker: Broker, worker):
        """Create the metric objects; skipped entirely when ``Conf().test`` is truthy."""
        # NOTE(review): ``Conf`` is not imported in the visible part of this
        # module; it is expected to be provided by the file's import block.
        if Conf().test:
            return

        # Imported lazily, after __init__ has set PROMETHEUS_MULTIPROC_DIR.
        from prometheus_client import Counter, Gauge, Histogram

        self.total_messages = Counter(
            f"{self.prefix}_tasks_total",
            "The total number of tasks processed.",
            self.labels,
        )
        self.total_errored_messages = Counter(
            f"{self.prefix}_tasks_errors_total",
            "The total number of errored tasks.",
            self.labels,
        )
        self.total_retried_messages = Counter(
            f"{self.prefix}_tasks_retries_total",
            "The total number of retried tasks.",
            self.labels,
        )
        self.total_rejected_messages = Counter(
            f"{self.prefix}_tasks_rejected_total",
            "The total number of dead-lettered tasks.",
            self.labels,
        )
        self.inprogress_messages = Gauge(
            f"{self.prefix}_tasks_inprogress",
            "The number of tasks in progress.",
            self.labels,
            multiprocess_mode="livesum",
        )
        self.inprogress_delayed_messages = Gauge(
            f"{self.prefix}_tasks_delayed_inprogress",
            "The number of delayed tasks in memory.",
            self.labels,
        )
        # The metric name keeps its historical spelling ("miliseconds"):
        # renaming it would change the exported metric name.
        self.messages_durations = Histogram(
            f"{self.prefix}_tasks_duration_miliseconds",
            "The time spent processing tasks.",
            self.labels,
            # Bucket upper bounds in milliseconds (durations are computed
            # from current_millis() in after_process_message).
            buckets=(
                5,
                10,
                25,
                50,
                75,
                100,
                250,
                500,
                750,
                1_000,
                2_500,
                5_000,
                7_500,
                10_000,
                30_000,
                60_000,
                600_000,
                900_000,
                1_800_000,
                3_600_000,
                float("inf"),
            ),
        )

    def after_worker_shutdown(self, broker: Broker, worker):
        """Tell prometheus_client's multiprocess support that this PID is gone."""
        from prometheus_client import multiprocess

        # TODO: worker_id
        multiprocess.mark_process_dead(os.getpid())

    def _make_labels(self, message: Message) -> list[str]:
        """Return the label values (queue name, actor name) for ``message``."""
        return [message.queue_name, message.actor_name]

    def after_nack(self, broker: Broker, message: Message):
        """Count a rejected (nack'ed) message."""
        self.total_rejected_messages.labels(*self._make_labels(message)).inc()

    def after_enqueue(self, broker: Broker, message: Message, delay: int):
        """Count a retry when the enqueued message carries a ``retries`` option."""
        if "retries" in message.options:
            self.total_retried_messages.labels(*self._make_labels(message)).inc()

    def before_delay_message(self, broker: Broker, message: Message):
        """Remember a delayed message and bump the delayed-in-progress gauge."""
        self.delayed_messages.add(message.message_id)
        self.inprogress_delayed_messages.labels(*self._make_labels(message)).inc()

    def before_process_message(self, broker: Broker, message: Message):
        """Mark a message as in progress and record its start time."""
        labels = self._make_labels(message)
        if message.message_id in self.delayed_messages:
            # The message is leaving the delayed set: undo before_delay_message.
            self.delayed_messages.remove(message.message_id)
            self.inprogress_delayed_messages.labels(*labels).dec()
        self.inprogress_messages.labels(*labels).inc()
        self.message_start_times[message.message_id] = current_millis()

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Record duration and totals for a finished (or skipped) message.

        When no start time was recorded for the message, the current time is
        used as the start, so the observed duration is approximately zero.
        """
        labels = self._make_labels(message)
        message_start_time = self.message_start_times.pop(message.message_id, current_millis())
        message_duration = current_millis() - message_start_time
        self.messages_durations.labels(*labels).observe(message_duration)
        self.inprogress_messages.labels(*labels).dec()
        self.total_messages.labels(*labels).inc()
        if exception is not None:
            self.total_errored_messages.labels(*labels).inc()

    # Skipped messages are accounted for exactly like processed ones.
    after_skip_message = after_process_message

    @staticmethod
    def run(addr: str, port: int):
        """Serve the multiprocess registry over HTTP, then block in ``pause()``.

        Args:
            addr: Address passed to ``start_http_server``.
            port: Port to listen on; converted with ``int()``.

        Raises:
            ValueError: If ``port`` cannot be converted to an integer. This is
                not caught here and propagates to the caller.

        An ``OSError`` while starting the server is logged as a warning and
        swallowed; ``pause()`` is reached in either case.
        """
        from prometheus_client import CollectorRegistry, multiprocess, start_http_server

        try:
            port = int(port)
            registry = CollectorRegistry()
            multiprocess.MultiProcessCollector(registry)
            start_http_server(port, addr, registry)
        except OSError:
            # Pass the class itself: ``type(MetricsMiddleware)`` is the
            # metaclass ``type`` and would name the logger "<module>.type".
            get_logger(__name__, MetricsMiddleware).warning(
                "Port is already in use, not starting metrics server"
            )
        pause()