move middlewares to package

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2025-06-19 14:57:17 +02:00
parent 8980282a02
commit 5a5176e21f
20 changed files with 141 additions and 116 deletions

View File

@ -2,6 +2,7 @@
from django.core.cache import cache from django.core.cache import cache
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq import actor from dramatiq import actor
from packaging.version import parse from packaging.version import parse
from requests import RequestException from requests import RequestException
@ -12,7 +13,7 @@ from authentik.admin.apps import PROM_INFO
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
VERSION_NULL = "0.0.0" VERSION_NULL = "0.0.0"
@ -35,7 +36,7 @@ def _set_prom_info():
@actor @actor
def update_latest_version(): def update_latest_version():
"""Update latest version info""" """Update latest version info"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
if CONFIG.get_bool("disable_update_check"): if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.info("Version check disabled.") self.info("Version check disabled.")

View File

@ -10,6 +10,7 @@ from dacite.core import from_dict
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from dramatiq.middleware import Middleware from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -35,7 +36,7 @@ from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.events.logs import capture_logs from authentik.events.logs import capture_logs
from authentik.events.utils import sanitize_dict from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
from authentik.tasks.schedules.models import Schedule from authentik.tasks.schedules.models import Schedule
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -147,7 +148,7 @@ def blueprints_find() -> list[BlueprintFile]:
@actor(throws=(DatabaseError, ProgrammingError, InternalError)) @actor(throws=(DatabaseError, ProgrammingError, InternalError))
def blueprints_discovery(path: str | None = None): def blueprints_discovery(path: str | None = None):
"""Find blueprints and check if they need to be created in the database""" """Find blueprints and check if they need to be created in the database"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
count = 0 count = 0
for blueprint in blueprints_find(): for blueprint in blueprints_find():
if path and blueprint.path != path: if path and blueprint.path != path:
@ -187,7 +188,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
@actor @actor
def apply_blueprint(instance_pk: UUID): def apply_blueprint(instance_pk: UUID):
"""Apply single blueprint""" """Apply single blueprint"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
self.set_uid(str(instance_pk)) self.set_uid(str(instance_pk))
instance: BlueprintInstance | None = None instance: BlueprintInstance | None = None
try: try:

View File

@ -3,6 +3,7 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from django.utils.timezone import now from django.utils.timezone import now
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -12,7 +13,7 @@ from authentik.core.models import (
ExpiringModel, ExpiringModel,
User, User,
) )
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -20,7 +21,7 @@ LOGGER = get_logger()
@actor @actor
def clean_expired_models(): def clean_expired_models():
"""Remove expired objects""" """Remove expired objects"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
for cls in ExpiringModel.__subclasses__(): for cls in ExpiringModel.__subclasses__():
cls: ExpiringModel cls: ExpiringModel
objects = ( objects = (
@ -36,7 +37,7 @@ def clean_expired_models():
@actor @actor
def clean_temporary_users(): def clean_temporary_users():
"""Remove temporary users created by SAML Sources""" """Remove temporary users created by SAML Sources"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
_now = datetime.now() _now = datetime.now()
deleted_users = 0 deleted_users = 0
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):

View File

@ -6,12 +6,13 @@ from pathlib import Path
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.base import load_pem_x509_certificate from cryptography.x509.base import load_pem_x509_certificate
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -37,7 +38,7 @@ def ensure_certificate_valid(body: str):
@actor @actor
def certificate_discovery(): def certificate_discovery():
"""Discover, import and update certificates from the filesystem""" """Discover, import and update certificates from the filesystem"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
certs = {} certs = {}
private_keys = {} private_keys = {}
discovered = 0 discovered = 0

View File

@ -1,4 +1,5 @@
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from structlog import get_logger from structlog import get_logger
@ -6,7 +7,7 @@ from authentik.enterprise.policies.unique_password.models import (
UniquePasswordPolicy, UniquePasswordPolicy,
UserPasswordHistory, UserPasswordHistory,
) )
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -16,7 +17,7 @@ def check_and_purge_password_history():
"""Check if any UniquePasswordPolicy exists, and if not, purge the password history table. """Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
This is run on a schedule instead of being triggered by policy binding deletion. This is run on a schedule instead of being triggered by policy binding deletion.
""" """
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
if not UniquePasswordPolicy.objects.exists(): if not UniquePasswordPolicy.objects.exists():
UserPasswordHistory.objects.all().delete() UserPasswordHistory.objects.all().delete()
@ -29,7 +30,7 @@ def check_and_purge_password_history():
@actor @actor
def trim_password_histories(): def trim_password_histories():
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
"""Removes rows from UserPasswordHistory older than """Removes rows from UserPasswordHistory older than
the `n` most recent entries. the `n` most recent entries.

View File

@ -3,6 +3,7 @@ from uuid import UUID
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.timezone import now from django.utils.timezone import now
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from requests.exceptions import RequestException from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -19,7 +20,7 @@ from authentik.events.logs import LogEvent
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
session = get_http_session() session = get_http_session()
LOGGER = get_logger() LOGGER = get_logger()
@ -62,7 +63,7 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:
@actor @actor
def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
stream = Stream.objects.filter(pk=stream_uuid).first() stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream: if not stream:

View File

@ -3,6 +3,7 @@
from uuid import UUID from uuid import UUID
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -16,7 +17,7 @@ from authentik.events.models import (
) )
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -110,7 +111,7 @@ def gdpr_cleanup(user_pk: int):
@actor @actor
def notification_cleanup(): def notification_cleanup():
"""Cleanup seen notifications and notifications whose event expired.""" """Cleanup seen notifications and notifications whose event expired."""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
amount = notifications.count() amount = notifications.count()
notifications.delete() notifications.delete()

View File

@ -2,6 +2,7 @@ from django.core.paginator import Paginator
from django.db.models import Model, QuerySet from django.db.models import Model, QuerySet
from django.db.models.query import Q from django.db.models.query import Q
from django.utils.text import slugify from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import Actor from dramatiq.actor import Actor
from dramatiq.composition import group from dramatiq.composition import group
from dramatiq.errors import Retry from dramatiq.errors import Retry
@ -20,7 +21,6 @@ from authentik.lib.sync.outgoing.exceptions import (
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.middleware import CurrentTask
from authentik.tasks.models import Task from authentik.tasks.models import Task
@ -60,7 +60,7 @@ class SyncTasks:
provider_pk: int, provider_pk: int,
sync_objects: Actor, sync_objects: Actor,
): ):
task = CurrentTask.get_task() task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk, provider_pk=provider_pk,
@ -114,7 +114,7 @@ class SyncTasks:
override_dry_run=False, override_dry_run=False,
**filter, **filter,
): ):
task = CurrentTask.get_task() task: Task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type) _object_type: type[Model] = path_to_class(object_type)
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
@ -186,7 +186,7 @@ class SyncTasks:
provider_pk: int, provider_pk: int,
raw_op: str, raw_op: str,
): ):
task = CurrentTask.get_task() task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )
@ -234,7 +234,7 @@ class SyncTasks:
action: str, action: str,
pk_set: list[int], pk_set: list[int],
): ):
task = CurrentTask.get_task() task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )

View File

@ -12,6 +12,7 @@ from channels.layers import get_channel_layer
from django.core.cache import cache from django.core.cache import cache
from django.db.models.base import Model from django.db.models.base import Model
from django.utils.text import slugify from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask
from docker.constants import DEFAULT_UNIX_SOCKET from docker.constants import DEFAULT_UNIX_SOCKET
from dramatiq.actor import actor from dramatiq.actor import actor
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
@ -42,7 +43,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
@ -109,7 +110,7 @@ def outpost_service_connection_monitor(connection_pk: Any):
@actor @actor
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
"""Create/update/monitor/delete the deployment of an Outpost""" """Create/update/monitor/delete the deployment of an Outpost"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
self.set_uid(outpost_pk) self.set_uid(outpost_pk)
logs = [] logs = []
if from_cache: if from_cache:
@ -144,7 +145,7 @@ def outpost_token_ensurer():
""" """
Periodically ensure that all Outposts have valid Service Accounts and Tokens Periodically ensure that all Outposts have valid Service Accounts and Tokens
""" """
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
all_outposts = Outpost.objects.all() all_outposts = Outpost.objects.all()
for outpost in all_outposts: for outpost in all_outposts:
_ = outpost.token _ = outpost.token
@ -227,7 +228,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
@actor @actor
def outpost_connection_discovery(): def outpost_connection_discovery():
"""Checks the local environment and create Service connections.""" """Checks the local environment and create Service connections."""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
if not CONFIG.get_bool("outposts.discover"): if not CONFIG.get_bool("outposts.discover"):
self.info("Outpost integration discovery is disabled") self.info("Outpost integration discovery is disabled")
return return

View File

@ -367,6 +367,7 @@ DRAMATIQ = {
"threads": CONFIG.get_int("worker.threads", 1), "threads": CONFIG.get_int("worker.threads", 1),
}, },
"middlewares": ( "middlewares": (
("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}),
# TODO: fixme # TODO: fixme
# ("dramatiq.middleware.prometheus.Prometheus", {}), # ("dramatiq.middleware.prometheus.Prometheus", {}),
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
@ -384,10 +385,9 @@ DRAMATIQ = {
{"max_retries": 20 if not TEST else 0}, {"max_retries": 20 if not TEST else 0},
), ),
# TODO: results # TODO: results
("authentik.tasks.middleware.FullyQualifiedActorName", {}),
("authentik.tasks.middleware.RelObjMiddleware", {}),
("authentik.tasks.middleware.TenantMiddleware", {}), ("authentik.tasks.middleware.TenantMiddleware", {}),
("authentik.tasks.middleware.CurrentTask", {}), ("authentik.tasks.middleware.RelObjMiddleware", {}),
("django_dramatiq_postgres.middleware.CurrentTask", {}),
), ),
"test": TEST, "test": TEST,
} }

View File

@ -1,6 +1,7 @@
"""Kerberos Sync tasks""" """Kerberos Sync tasks"""
from django.core.cache import cache from django.core.cache import cache
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -9,7 +10,7 @@ from authentik.lib.sync.outgoing.exceptions import StopSync
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.sources.kerberos.models import KerberosSource from authentik.sources.kerberos.models import KerberosSource
from authentik.sources.kerberos.sync import KerberosSync from authentik.sources.kerberos.sync import KerberosSync
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/" CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"
@ -30,7 +31,7 @@ def kerberos_connectivity_check(pk: str):
@actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000) @actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000)
def kerberos_sync(pk: str): def kerberos_sync(pk: str):
"""Sync a single source""" """Sync a single source"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first() source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first()
if not source: if not source:
return return

View File

@ -3,6 +3,7 @@
from uuid import uuid4 from uuid import uuid4
from django.core.cache import cache from django.core.cache import cache
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from dramatiq.composition import group from dramatiq.composition import group
from dramatiq.message import Message from dramatiq.message import Message
@ -20,7 +21,7 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
SYNC_CLASSES = [ SYNC_CLASSES = [
@ -118,7 +119,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
@actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000) @actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000)
def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str): def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
"""Synchronization of an LDAP Source""" """Synchronization of an LDAP Source"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
if not source: if not source:
# Because the source couldn't be found, we don't have a UID # Because the source couldn't be found, we don't have a UID

View File

@ -2,13 +2,14 @@
from json import dumps from json import dumps
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from requests import RequestException from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -16,7 +17,7 @@ LOGGER = get_logger()
@actor @actor
def update_well_known_jwks(): def update_well_known_jwks():
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL""" """Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
session = get_http_session() session = get_http_session()
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""): for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
try: try:

View File

@ -1,5 +1,6 @@
"""Plex tasks""" """Plex tasks"""
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from requests import RequestException from requests import RequestException
@ -7,13 +8,13 @@ from authentik.events.models import Event, EventAction
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.sources.plex.models import PlexSource from authentik.sources.plex.models import PlexSource
from authentik.sources.plex.plex import PlexAuth from authentik.sources.plex.plex import PlexAuth
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
@actor @actor
def check_plex_token(source_pk: str): def check_plex_token(source_pk: str):
"""Check the validity of a Plex source.""" """Check the validity of a Plex source."""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
sources = PlexSource.objects.filter(pk=source_pk) sources = PlexSource.objects.filter(pk=source_pk)
if not sources.exists(): if not sources.exists():
return return

View File

@ -6,6 +6,7 @@ from pathlib import Path
from django.core.cache import cache from django.core.cache import cache
from django.db.transaction import atomic from django.db.transaction import atomic
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from fido2.mds3 import filter_revoked, parse_blob from fido2.mds3 import filter_revoked, parse_blob
@ -13,7 +14,7 @@ from authentik.stages.authenticator_webauthn.models import (
UNKNOWN_DEVICE_TYPE_AAGUID, UNKNOWN_DEVICE_TYPE_AAGUID,
WebAuthnDeviceType, WebAuthnDeviceType,
) )
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no" CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no"
AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json" AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json"
@ -31,7 +32,7 @@ def mds_ca() -> bytes:
@actor @actor
def webauthn_mds_import(force=False): def webauthn_mds_import(force=False):
"""Background task to import FIDO Alliance MDS blob and AAGUIDs into database""" """Background task to import FIDO Alliance MDS blob and AAGUIDs into database"""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
with open(MDS_BLOB_PATH, mode="rb") as _raw_blob: with open(MDS_BLOB_PATH, mode="rb") as _raw_blob:
blob = parse_blob(_raw_blob.read(), mds_ca()) blob = parse_blob(_raw_blob.read(), mds_ca())
to_create_update = [ to_create_update = [

View File

@ -6,6 +6,7 @@ from typing import Any
from django.core.mail import EmailMultiAlternatives from django.core.mail import EmailMultiAlternatives
from django.core.mail.utils import DNS_NAME from django.core.mail.utils import DNS_NAME
from django.utils.text import slugify from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor from dramatiq.actor import actor
from dramatiq.composition import group from dramatiq.composition import group
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -15,7 +16,7 @@ from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.stages.authenticator_email.models import AuthenticatorEmailStage from authentik.stages.authenticator_email.models import AuthenticatorEmailStage
from authentik.stages.email.models import EmailStage from authentik.stages.email.models import EmailStage
from authentik.stages.email.utils import logo_data from authentik.stages.email.utils import logo_data
from authentik.tasks.middleware import CurrentTask from authentik.tasks.models import Task
LOGGER = get_logger() LOGGER = get_logger()
@ -54,7 +55,7 @@ def send_mail(
email_stage_pk: str | None = None, email_stage_pk: str | None = None,
): ):
"""Send Email for Email Stage. Retries are scheduled automatically.""" """Send Email for Email Stage. Retries are scheduled automatically."""
self = CurrentTask.get_task() self: Task = CurrentTask.get_task()
message_id = make_msgid(domain=DNS_NAME) message_id = make_msgid(domain=DNS_NAME)
self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_"))) self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_")))
if not stage_class_path or not email_stage_pk: if not stage_class_path or not email_stage_pk:

View File

@ -1,21 +1,11 @@
from typing import Any from django.db.models import QuerySet
from django_dramatiq_postgres.broker import PostgresBroker from django_dramatiq_postgres.broker import PostgresBroker
from dramatiq.message import Message
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger() LOGGER = get_logger()
class Broker(PostgresBroker): class Broker(PostgresBroker):
def model_defaults(self, message: Message) -> dict[str, Any]: @property
rel_obj = message.options.get("rel_obj") def query_set(self) -> QuerySet:
if rel_obj: return self.model.objects.select_related("tenant").using(self.db_alias)
del message.options["rel_obj"]
return {
"tenant": get_current_tenant(),
"rel_obj": rel_obj,
**super().model_defaults(message),
}

View File

@ -1,16 +1,22 @@
import contextvars
from typing import Any
from dramatiq.actor import Actor
from dramatiq.broker import Broker from dramatiq.broker import Broker
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import Middleware from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger
from authentik.tasks.models import Task from authentik.tasks.models import Task
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
LOGGER = get_logger()
class TenantMiddleware(Middleware):
def before_enqueue(self, broker: Broker, message: Message, delay: int):
message.options["model_defaults"]["tenant"] = get_current_tenant()
def before_process_message(self, broker: Broker, message: Message):
task: Task = message.options["task"]
task.tenant.activate()
def after_process_message(self, *args, **kwargs):
Tenant.deactivate()
class RelObjMiddleware(Middleware): class RelObjMiddleware(Middleware):
@ -18,52 +24,7 @@ class RelObjMiddleware(Middleware):
def actor_options(self): def actor_options(self):
return {"rel_obj"} return {"rel_obj"}
def before_enqueue(self, broker: Broker, message: Message, delay: int):
class FullyQualifiedActorName(Middleware): if rel_obj := message.options.get("rel_obj"):
def before_declare_actor(self, broker: Broker, actor: Actor): del message.options["rel_obj"]
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" message.options["model_defaults"]["rel_obj"] = rel_obj
class CurrentTask(Middleware):
_TASK: contextvars.ContextVar[list[Task] | None] = contextvars.ContextVar(
"_TASK",
default=None,
)
@classmethod
def get_task(cls) -> Task:
task = cls._TASK.get()
if not task:
raise RuntimeError("CurrentTask.get_task() should only be called in a running task")
return task[-1]
def before_process_message(self, broker: Broker, message: Message):
tasks = self._TASK.get()
if tasks is None:
tasks = []
tasks.append(Task.objects.get(message_id=message.message_id))
self._TASK.set(tasks)
def after_process_message(
self,
broker: Broker,
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
):
tasks: list[Task] | None = self._TASK.get()
if tasks is None or len(tasks) == 0:
LOGGER.warn("Task was None, not saving. This should not happen")
return
else:
tasks[-1].save()
self._TASK.set(tasks[:-1])
class TenantMiddleware(Middleware):
def before_process_message(self, broker: Broker, message: Message):
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
def after_process_message(self, *args, **kwargs):
Tenant.deactivate()

View File

@ -23,6 +23,7 @@ from django.utils.module_loading import import_string
from dramatiq.broker import Broker, Consumer, MessageProxy from dramatiq.broker import Broker, Consumer, MessageProxy
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
from dramatiq.errors import ConnectionError, QueueJoinTimeout from dramatiq.errors import ConnectionError, QueueJoinTimeout
from dramatiq.logging import get_logger
from dramatiq.message import Message from dramatiq.message import Message
from dramatiq.middleware import ( from dramatiq.middleware import (
Middleware, Middleware,
@ -30,12 +31,11 @@ from dramatiq.middleware import (
from pglock.core import _cast_lock_id from pglock.core import _cast_lock_id
from psycopg import Notify, sql from psycopg import Notify, sql
from psycopg.errors import AdminShutdown from psycopg.errors import AdminShutdown
from structlog.stdlib import get_logger
from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
LOGGER = get_logger() logger = get_logger(__name__)
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str: def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
@ -62,7 +62,7 @@ class PostgresBroker(Broker):
**kwargs, **kwargs,
): ):
super().__init__(*args, middleware=[], **kwargs) super().__init__(*args, middleware=[], **kwargs)
self.logger = get_logger().bind() self.logger = get_logger(__name__, type(self))
self.queues = set() self.queues = set()
@ -131,7 +131,7 @@ class PostgresBroker(Broker):
reraise=True, reraise=True,
wait=tenacity.wait_random_exponential(multiplier=1, max=30), wait=tenacity.wait_random_exponential(multiplier=1, max=30),
stop=tenacity.stop_after_attempt(10), stop=tenacity.stop_after_attempt(10),
before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True), before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
) )
def enqueue(self, message: Message, *, delay: int | None = None) -> Message: def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
canonical_queue_name = message.queue_name canonical_queue_name = message.queue_name
@ -148,20 +148,26 @@ class PostgresBroker(Broker):
self.declare_queue(canonical_queue_name) self.declare_queue(canonical_queue_name)
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}") self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
message.options["model_defaults"] = self.model_defaults(message)
self.emit_before("enqueue", message, delay) self.emit_before("enqueue", message, delay)
query = { query = {
"message_id": message.message_id, "message_id": message.message_id,
} }
defaults = self.model_defaults(message) defaults = message.options["model_defaults"]
del message.options["model_defaults"]
create_defaults = { create_defaults = {
**query, **query,
**defaults, **defaults,
} }
self.query_set.update_or_create( self.query_set.update_or_create(
**query, **query,
defaults=defaults, defaults=defaults,
create_defaults=create_defaults, create_defaults=create_defaults,
) )
self.emit_after("enqueue", message, delay) self.emit_after("enqueue", message, delay)
return message return message
@ -209,7 +215,7 @@ class _PostgresConsumer(Consumer):
timeout: int, timeout: int,
**kwargs, **kwargs,
): ):
self.logger = get_logger().bind() self.logger = get_logger(__name__, type(self))
self.notifies: list[Notify] = [] self.notifies: list[Notify] = []
self.broker = broker self.broker = broker

View File

@ -1,10 +1,18 @@
import contextvars
from typing import Any
from django.db import ( from django.db import (
close_old_connections, close_old_connections,
connections, connections,
) )
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.logging import get_logger
from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware from dramatiq.middleware.middleware import Middleware
from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.models import TaskBase
class DbConnectionMiddleware(Middleware): class DbConnectionMiddleware(Middleware):
@ -22,3 +30,49 @@ class DbConnectionMiddleware(Middleware):
before_consumer_thread_shutdown = _close_connections before_consumer_thread_shutdown = _close_connections
before_worker_thread_shutdown = _close_connections before_worker_thread_shutdown = _close_connections
before_worker_shutdown = _close_connections before_worker_shutdown = _close_connections
class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor):
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
class CurrentTask(Middleware):
def __init__(self):
self.logger = get_logger(__name__, type(self))
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
_TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
"_TASKS",
default=None,
)
@classmethod
def get_task(cls) -> TaskBase:
task = cls._TASKS.get()
if not task:
raise RuntimeError("CurrentTask.get_task() can only be called in a running task")
return task[-1]
def before_process_message(self, broker: Broker, message: Message):
tasks = self._TASKS.get()
if tasks is None:
tasks = []
tasks.append(message.options["task"])
self._TASKS.set(tasks)
def after_process_message(
self,
broker: Broker,
message: Message,
*,
result: Any | None = None,
exception: Exception | None = None,
):
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
self.logger.warning("Task was None, not saving. This should not happen.")
return
else:
tasks[-1].save()
self._TASKS.set(tasks[:-1])