move middlewares to package
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq import actor
|
from dramatiq import actor
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
from requests import RequestException
|
from requests import RequestException
|
||||||
@ -12,7 +13,7 @@ from authentik.admin.apps import PROM_INFO
|
|||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
VERSION_NULL = "0.0.0"
|
VERSION_NULL = "0.0.0"
|
||||||
@ -35,7 +36,7 @@ def _set_prom_info():
|
|||||||
@actor
|
@actor
|
||||||
def update_latest_version():
|
def update_latest_version():
|
||||||
"""Update latest version info"""
|
"""Update latest version info"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
if CONFIG.get_bool("disable_update_check"):
|
if CONFIG.get_bool("disable_update_check"):
|
||||||
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
|
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
|
||||||
self.info("Version check disabled.")
|
self.info("Version check disabled.")
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from dacite.core import from_dict
|
|||||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from dramatiq.middleware import Middleware
|
from dramatiq.middleware import Middleware
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -35,7 +36,7 @@ from authentik.blueprints.v1.oci import OCI_PREFIX
|
|||||||
from authentik.events.logs import capture_logs
|
from authentik.events.logs import capture_logs
|
||||||
from authentik.events.utils import sanitize_dict
|
from authentik.events.utils import sanitize_dict
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
from authentik.tasks.schedules.models import Schedule
|
from authentik.tasks.schedules.models import Schedule
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
|
||||||
@ -147,7 +148,7 @@ def blueprints_find() -> list[BlueprintFile]:
|
|||||||
@actor(throws=(DatabaseError, ProgrammingError, InternalError))
|
@actor(throws=(DatabaseError, ProgrammingError, InternalError))
|
||||||
def blueprints_discovery(path: str | None = None):
|
def blueprints_discovery(path: str | None = None):
|
||||||
"""Find blueprints and check if they need to be created in the database"""
|
"""Find blueprints and check if they need to be created in the database"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
count = 0
|
count = 0
|
||||||
for blueprint in blueprints_find():
|
for blueprint in blueprints_find():
|
||||||
if path and blueprint.path != path:
|
if path and blueprint.path != path:
|
||||||
@ -187,7 +188,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
|
|||||||
@actor
|
@actor
|
||||||
def apply_blueprint(instance_pk: UUID):
|
def apply_blueprint(instance_pk: UUID):
|
||||||
"""Apply single blueprint"""
|
"""Apply single blueprint"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
self.set_uid(str(instance_pk))
|
self.set_uid(str(instance_pk))
|
||||||
instance: BlueprintInstance | None = None
|
instance: BlueprintInstance | None = None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
@ -12,7 +13,7 @@ from authentik.core.models import (
|
|||||||
ExpiringModel,
|
ExpiringModel,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ LOGGER = get_logger()
|
|||||||
@actor
|
@actor
|
||||||
def clean_expired_models():
|
def clean_expired_models():
|
||||||
"""Remove expired objects"""
|
"""Remove expired objects"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
for cls in ExpiringModel.__subclasses__():
|
for cls in ExpiringModel.__subclasses__():
|
||||||
cls: ExpiringModel
|
cls: ExpiringModel
|
||||||
objects = (
|
objects = (
|
||||||
@ -36,7 +37,7 @@ def clean_expired_models():
|
|||||||
@actor
|
@actor
|
||||||
def clean_temporary_users():
|
def clean_temporary_users():
|
||||||
"""Remove temporary users created by SAML Sources"""
|
"""Remove temporary users created by SAML Sources"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
_now = datetime.now()
|
_now = datetime.now()
|
||||||
deleted_users = 0
|
deleted_users = 0
|
||||||
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
|
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
|
||||||
|
|||||||
@ -6,12 +6,13 @@ from pathlib import Path
|
|||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
from cryptography.hazmat.primitives.serialization import load_pem_private_key
|
||||||
from cryptography.x509.base import load_pem_x509_certificate
|
from cryptography.x509.base import load_pem_x509_certificate
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.crypto.models import CertificateKeyPair
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.lib.config import CONFIG
|
from authentik.lib.config import CONFIG
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -37,7 +38,7 @@ def ensure_certificate_valid(body: str):
|
|||||||
@actor
|
@actor
|
||||||
def certificate_discovery():
|
def certificate_discovery():
|
||||||
"""Discover, import and update certificates from the filesystem"""
|
"""Discover, import and update certificates from the filesystem"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
certs = {}
|
certs = {}
|
||||||
private_keys = {}
|
private_keys = {}
|
||||||
discovered = 0
|
discovered = 0
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from django.db.models.aggregates import Count
|
from django.db.models.aggregates import Count
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from structlog import get_logger
|
from structlog import get_logger
|
||||||
|
|
||||||
@ -6,7 +7,7 @@ from authentik.enterprise.policies.unique_password.models import (
|
|||||||
UniquePasswordPolicy,
|
UniquePasswordPolicy,
|
||||||
UserPasswordHistory,
|
UserPasswordHistory,
|
||||||
)
|
)
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ def check_and_purge_password_history():
|
|||||||
"""Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
|
"""Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
|
||||||
This is run on a schedule instead of being triggered by policy binding deletion.
|
This is run on a schedule instead of being triggered by policy binding deletion.
|
||||||
"""
|
"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
|
|
||||||
if not UniquePasswordPolicy.objects.exists():
|
if not UniquePasswordPolicy.objects.exists():
|
||||||
UserPasswordHistory.objects.all().delete()
|
UserPasswordHistory.objects.all().delete()
|
||||||
@ -29,7 +30,7 @@ def check_and_purge_password_history():
|
|||||||
|
|
||||||
@actor
|
@actor
|
||||||
def trim_password_histories():
|
def trim_password_histories():
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
|
|
||||||
"""Removes rows from UserPasswordHistory older than
|
"""Removes rows from UserPasswordHistory older than
|
||||||
the `n` most recent entries.
|
the `n` most recent entries.
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -19,7 +20,7 @@ from authentik.events.logs import LogEvent
|
|||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.lib.utils.time import timedelta_from_string
|
from authentik.lib.utils.time import timedelta_from_string
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
session = get_http_session()
|
session = get_http_session()
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
@ -62,7 +63,7 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:
|
|||||||
|
|
||||||
@actor
|
@actor
|
||||||
def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
|
def _send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
|
|
||||||
stream = Stream.objects.filter(pk=stream_uuid).first()
|
stream = Stream.objects.filter(pk=stream_uuid).first()
|
||||||
if not stream:
|
if not stream:
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.db.models.query_utils import Q
|
from django.db.models.query_utils import Q
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from guardian.shortcuts import get_anonymous_user
|
from guardian.shortcuts import get_anonymous_user
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -16,7 +17,7 @@ from authentik.events.models import (
|
|||||||
)
|
)
|
||||||
from authentik.policies.engine import PolicyEngine
|
from authentik.policies.engine import PolicyEngine
|
||||||
from authentik.policies.models import PolicyBinding, PolicyEngineMode
|
from authentik.policies.models import PolicyBinding, PolicyEngineMode
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -110,7 +111,7 @@ def gdpr_cleanup(user_pk: int):
|
|||||||
@actor
|
@actor
|
||||||
def notification_cleanup():
|
def notification_cleanup():
|
||||||
"""Cleanup seen notifications and notifications whose event expired."""
|
"""Cleanup seen notifications and notifications whose event expired."""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
|
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
|
||||||
amount = notifications.count()
|
amount = notifications.count()
|
||||||
notifications.delete()
|
notifications.delete()
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from django.core.paginator import Paginator
|
|||||||
from django.db.models import Model, QuerySet
|
from django.db.models import Model, QuerySet
|
||||||
from django.db.models.query import Q
|
from django.db.models.query import Q
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import Actor
|
from dramatiq.actor import Actor
|
||||||
from dramatiq.composition import group
|
from dramatiq.composition import group
|
||||||
from dramatiq.errors import Retry
|
from dramatiq.errors import Retry
|
||||||
@ -20,7 +21,6 @@ from authentik.lib.sync.outgoing.exceptions import (
|
|||||||
)
|
)
|
||||||
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
|
||||||
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
from authentik.lib.utils.reflection import class_to_path, path_to_class
|
||||||
from authentik.tasks.middleware import CurrentTask
|
|
||||||
from authentik.tasks.models import Task
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class SyncTasks:
|
|||||||
provider_pk: int,
|
provider_pk: int,
|
||||||
sync_objects: Actor,
|
sync_objects: Actor,
|
||||||
):
|
):
|
||||||
task = CurrentTask.get_task()
|
task: Task = CurrentTask.get_task()
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
provider_pk=provider_pk,
|
provider_pk=provider_pk,
|
||||||
@ -114,7 +114,7 @@ class SyncTasks:
|
|||||||
override_dry_run=False,
|
override_dry_run=False,
|
||||||
**filter,
|
**filter,
|
||||||
):
|
):
|
||||||
task = CurrentTask.get_task()
|
task: Task = CurrentTask.get_task()
|
||||||
_object_type: type[Model] = path_to_class(object_type)
|
_object_type: type[Model] = path_to_class(object_type)
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
@ -186,7 +186,7 @@ class SyncTasks:
|
|||||||
provider_pk: int,
|
provider_pk: int,
|
||||||
raw_op: str,
|
raw_op: str,
|
||||||
):
|
):
|
||||||
task = CurrentTask.get_task()
|
task: Task = CurrentTask.get_task()
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
)
|
)
|
||||||
@ -234,7 +234,7 @@ class SyncTasks:
|
|||||||
action: str,
|
action: str,
|
||||||
pk_set: list[int],
|
pk_set: list[int],
|
||||||
):
|
):
|
||||||
task = CurrentTask.get_task()
|
task: Task = CurrentTask.get_task()
|
||||||
self.logger = get_logger().bind(
|
self.logger = get_logger().bind(
|
||||||
provider_type=class_to_path(self._provider_model),
|
provider_type=class_to_path(self._provider_model),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from channels.layers import get_channel_layer
|
|||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from docker.constants import DEFAULT_UNIX_SOCKET
|
from docker.constants import DEFAULT_UNIX_SOCKET
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
|
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
|
||||||
@ -42,7 +43,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
|
|||||||
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
|
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||||
from authentik.providers.radius.controllers.docker import RadiusDockerController
|
from authentik.providers.radius.controllers.docker import RadiusDockerController
|
||||||
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
|
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
|
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
|
||||||
@ -109,7 +110,7 @@ def outpost_service_connection_monitor(connection_pk: Any):
|
|||||||
@actor
|
@actor
|
||||||
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
|
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
|
||||||
"""Create/update/monitor/delete the deployment of an Outpost"""
|
"""Create/update/monitor/delete the deployment of an Outpost"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
self.set_uid(outpost_pk)
|
self.set_uid(outpost_pk)
|
||||||
logs = []
|
logs = []
|
||||||
if from_cache:
|
if from_cache:
|
||||||
@ -144,7 +145,7 @@ def outpost_token_ensurer():
|
|||||||
"""
|
"""
|
||||||
Periodically ensure that all Outposts have valid Service Accounts and Tokens
|
Periodically ensure that all Outposts have valid Service Accounts and Tokens
|
||||||
"""
|
"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
all_outposts = Outpost.objects.all()
|
all_outposts = Outpost.objects.all()
|
||||||
for outpost in all_outposts:
|
for outpost in all_outposts:
|
||||||
_ = outpost.token
|
_ = outpost.token
|
||||||
@ -227,7 +228,7 @@ def _outpost_single_update(outpost: Outpost, layer=None):
|
|||||||
@actor
|
@actor
|
||||||
def outpost_connection_discovery():
|
def outpost_connection_discovery():
|
||||||
"""Checks the local environment and create Service connections."""
|
"""Checks the local environment and create Service connections."""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
if not CONFIG.get_bool("outposts.discover"):
|
if not CONFIG.get_bool("outposts.discover"):
|
||||||
self.info("Outpost integration discovery is disabled")
|
self.info("Outpost integration discovery is disabled")
|
||||||
return
|
return
|
||||||
|
|||||||
@ -367,6 +367,7 @@ DRAMATIQ = {
|
|||||||
"threads": CONFIG.get_int("worker.threads", 1),
|
"threads": CONFIG.get_int("worker.threads", 1),
|
||||||
},
|
},
|
||||||
"middlewares": (
|
"middlewares": (
|
||||||
|
("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}),
|
||||||
# TODO: fixme
|
# TODO: fixme
|
||||||
# ("dramatiq.middleware.prometheus.Prometheus", {}),
|
# ("dramatiq.middleware.prometheus.Prometheus", {}),
|
||||||
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
|
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
|
||||||
@ -384,10 +385,9 @@ DRAMATIQ = {
|
|||||||
{"max_retries": 20 if not TEST else 0},
|
{"max_retries": 20 if not TEST else 0},
|
||||||
),
|
),
|
||||||
# TODO: results
|
# TODO: results
|
||||||
("authentik.tasks.middleware.FullyQualifiedActorName", {}),
|
|
||||||
("authentik.tasks.middleware.RelObjMiddleware", {}),
|
|
||||||
("authentik.tasks.middleware.TenantMiddleware", {}),
|
("authentik.tasks.middleware.TenantMiddleware", {}),
|
||||||
("authentik.tasks.middleware.CurrentTask", {}),
|
("authentik.tasks.middleware.RelObjMiddleware", {}),
|
||||||
|
("django_dramatiq_postgres.middleware.CurrentTask", {}),
|
||||||
),
|
),
|
||||||
"test": TEST,
|
"test": TEST,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""Kerberos Sync tasks"""
|
"""Kerberos Sync tasks"""
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
@ -9,7 +10,7 @@ from authentik.lib.sync.outgoing.exceptions import StopSync
|
|||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.sources.kerberos.models import KerberosSource
|
from authentik.sources.kerberos.models import KerberosSource
|
||||||
from authentik.sources.kerberos.sync import KerberosSync
|
from authentik.sources.kerberos.sync import KerberosSync
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"
|
CACHE_KEY_STATUS = "goauthentik.io/sources/kerberos/status/"
|
||||||
@ -30,7 +31,7 @@ def kerberos_connectivity_check(pk: str):
|
|||||||
@actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000)
|
@actor(time_limit=(60 * 60 * CONFIG.get_int("sources.kerberos.task_timeout_hours")) * 2.5 * 1000)
|
||||||
def kerberos_sync(pk: str):
|
def kerberos_sync(pk: str):
|
||||||
"""Sync a single source"""
|
"""Sync a single source"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first()
|
source: KerberosSource = KerberosSource.objects.filter(enabled=True, pk=pk).first()
|
||||||
if not source:
|
if not source:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from dramatiq.composition import group
|
from dramatiq.composition import group
|
||||||
from dramatiq.message import Message
|
from dramatiq.message import Message
|
||||||
@ -20,7 +21,7 @@ from authentik.sources.ldap.sync.forward_delete_users import UserLDAPForwardDele
|
|||||||
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
|
from authentik.sources.ldap.sync.groups import GroupLDAPSynchronizer
|
||||||
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
|
from authentik.sources.ldap.sync.membership import MembershipLDAPSynchronizer
|
||||||
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
|
from authentik.sources.ldap.sync.users import UserLDAPSynchronizer
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
SYNC_CLASSES = [
|
SYNC_CLASSES = [
|
||||||
@ -118,7 +119,7 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) ->
|
|||||||
@actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000)
|
@actor(time_limit=60 * 60 * CONFIG.get_int("ldap.task_timeout_hours") * 1000)
|
||||||
def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
|
def ldap_sync_page(source_pk: str, sync_class: str, page_cache_key: str):
|
||||||
"""Synchronization of an LDAP Source"""
|
"""Synchronization of an LDAP Source"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first()
|
||||||
if not source:
|
if not source:
|
||||||
# Because the source couldn't be found, we don't have a UID
|
# Because the source couldn't be found, we don't have a UID
|
||||||
|
|||||||
@ -2,13 +2,14 @@
|
|||||||
|
|
||||||
from json import dumps
|
from json import dumps
|
||||||
|
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from requests import RequestException
|
from requests import RequestException
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ LOGGER = get_logger()
|
|||||||
@actor
|
@actor
|
||||||
def update_well_known_jwks():
|
def update_well_known_jwks():
|
||||||
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
|
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
session = get_http_session()
|
session = get_http_session()
|
||||||
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
|
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Plex tasks"""
|
"""Plex tasks"""
|
||||||
|
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from requests import RequestException
|
from requests import RequestException
|
||||||
|
|
||||||
@ -7,13 +8,13 @@ from authentik.events.models import Event, EventAction
|
|||||||
from authentik.lib.utils.errors import exception_to_string
|
from authentik.lib.utils.errors import exception_to_string
|
||||||
from authentik.sources.plex.models import PlexSource
|
from authentik.sources.plex.models import PlexSource
|
||||||
from authentik.sources.plex.plex import PlexAuth
|
from authentik.sources.plex.plex import PlexAuth
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
|
|
||||||
@actor
|
@actor
|
||||||
def check_plex_token(source_pk: str):
|
def check_plex_token(source_pk: str):
|
||||||
"""Check the validity of a Plex source."""
|
"""Check the validity of a Plex source."""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
sources = PlexSource.objects.filter(pk=source_pk)
|
sources = PlexSource.objects.filter(pk=source_pk)
|
||||||
if not sources.exists():
|
if not sources.exists():
|
||||||
return
|
return
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db.transaction import atomic
|
from django.db.transaction import atomic
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from fido2.mds3 import filter_revoked, parse_blob
|
from fido2.mds3 import filter_revoked, parse_blob
|
||||||
|
|
||||||
@ -13,7 +14,7 @@ from authentik.stages.authenticator_webauthn.models import (
|
|||||||
UNKNOWN_DEVICE_TYPE_AAGUID,
|
UNKNOWN_DEVICE_TYPE_AAGUID,
|
||||||
WebAuthnDeviceType,
|
WebAuthnDeviceType,
|
||||||
)
|
)
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no"
|
CACHE_KEY_MDS_NO = "goauthentik.io/stages/authenticator_webauthn/mds_no"
|
||||||
AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json"
|
AAGUID_BLOB_PATH = Path(__file__).parent / "mds" / "aaguid.json"
|
||||||
@ -31,7 +32,7 @@ def mds_ca() -> bytes:
|
|||||||
@actor
|
@actor
|
||||||
def webauthn_mds_import(force=False):
|
def webauthn_mds_import(force=False):
|
||||||
"""Background task to import FIDO Alliance MDS blob and AAGUIDs into database"""
|
"""Background task to import FIDO Alliance MDS blob and AAGUIDs into database"""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
with open(MDS_BLOB_PATH, mode="rb") as _raw_blob:
|
with open(MDS_BLOB_PATH, mode="rb") as _raw_blob:
|
||||||
blob = parse_blob(_raw_blob.read(), mds_ca())
|
blob = parse_blob(_raw_blob.read(), mds_ca())
|
||||||
to_create_update = [
|
to_create_update = [
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Any
|
|||||||
from django.core.mail import EmailMultiAlternatives
|
from django.core.mail import EmailMultiAlternatives
|
||||||
from django.core.mail.utils import DNS_NAME
|
from django.core.mail.utils import DNS_NAME
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
|
from django_dramatiq_postgres.middleware import CurrentTask
|
||||||
from dramatiq.actor import actor
|
from dramatiq.actor import actor
|
||||||
from dramatiq.composition import group
|
from dramatiq.composition import group
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
@ -15,7 +16,7 @@ from authentik.lib.utils.reflection import class_to_path, path_to_class
|
|||||||
from authentik.stages.authenticator_email.models import AuthenticatorEmailStage
|
from authentik.stages.authenticator_email.models import AuthenticatorEmailStage
|
||||||
from authentik.stages.email.models import EmailStage
|
from authentik.stages.email.models import EmailStage
|
||||||
from authentik.stages.email.utils import logo_data
|
from authentik.stages.email.utils import logo_data
|
||||||
from authentik.tasks.middleware import CurrentTask
|
from authentik.tasks.models import Task
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ def send_mail(
|
|||||||
email_stage_pk: str | None = None,
|
email_stage_pk: str | None = None,
|
||||||
):
|
):
|
||||||
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
"""Send Email for Email Stage. Retries are scheduled automatically."""
|
||||||
self = CurrentTask.get_task()
|
self: Task = CurrentTask.get_task()
|
||||||
message_id = make_msgid(domain=DNS_NAME)
|
message_id = make_msgid(domain=DNS_NAME)
|
||||||
self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_")))
|
self.set_uid(slugify(message_id.replace(".", "_").replace("@", "_")))
|
||||||
if not stage_class_path or not email_stage_pk:
|
if not stage_class_path or not email_stage_pk:
|
||||||
|
|||||||
@ -1,21 +1,11 @@
|
|||||||
from typing import Any
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
from django_dramatiq_postgres.broker import PostgresBroker
|
from django_dramatiq_postgres.broker import PostgresBroker
|
||||||
from dramatiq.message import Message
|
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.tenants.utils import get_current_tenant
|
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class Broker(PostgresBroker):
|
class Broker(PostgresBroker):
|
||||||
def model_defaults(self, message: Message) -> dict[str, Any]:
|
@property
|
||||||
rel_obj = message.options.get("rel_obj")
|
def query_set(self) -> QuerySet:
|
||||||
if rel_obj:
|
return self.model.objects.select_related("tenant").using(self.db_alias)
|
||||||
del message.options["rel_obj"]
|
|
||||||
return {
|
|
||||||
"tenant": get_current_tenant(),
|
|
||||||
"rel_obj": rel_obj,
|
|
||||||
**super().model_defaults(message),
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,16 +1,22 @@
|
|||||||
import contextvars
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from dramatiq.actor import Actor
|
|
||||||
from dramatiq.broker import Broker
|
from dramatiq.broker import Broker
|
||||||
from dramatiq.message import Message
|
from dramatiq.message import Message
|
||||||
from dramatiq.middleware import Middleware
|
from dramatiq.middleware import Middleware
|
||||||
from structlog.stdlib import get_logger
|
|
||||||
|
|
||||||
from authentik.tasks.models import Task
|
from authentik.tasks.models import Task
|
||||||
from authentik.tenants.models import Tenant
|
from authentik.tenants.models import Tenant
|
||||||
|
from authentik.tenants.utils import get_current_tenant
|
||||||
|
|
||||||
LOGGER = get_logger()
|
|
||||||
|
class TenantMiddleware(Middleware):
|
||||||
|
def before_enqueue(self, broker: Broker, message: Message, delay: int):
|
||||||
|
message.options["model_defaults"]["tenant"] = get_current_tenant()
|
||||||
|
|
||||||
|
def before_process_message(self, broker: Broker, message: Message):
|
||||||
|
task: Task = message.options["task"]
|
||||||
|
task.tenant.activate()
|
||||||
|
|
||||||
|
def after_process_message(self, *args, **kwargs):
|
||||||
|
Tenant.deactivate()
|
||||||
|
|
||||||
|
|
||||||
class RelObjMiddleware(Middleware):
|
class RelObjMiddleware(Middleware):
|
||||||
@ -18,52 +24,7 @@ class RelObjMiddleware(Middleware):
|
|||||||
def actor_options(self):
|
def actor_options(self):
|
||||||
return {"rel_obj"}
|
return {"rel_obj"}
|
||||||
|
|
||||||
|
def before_enqueue(self, broker: Broker, message: Message, delay: int):
|
||||||
class FullyQualifiedActorName(Middleware):
|
if rel_obj := message.options.get("rel_obj"):
|
||||||
def before_declare_actor(self, broker: Broker, actor: Actor):
|
del message.options["rel_obj"]
|
||||||
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
|
message.options["model_defaults"]["rel_obj"] = rel_obj
|
||||||
|
|
||||||
|
|
||||||
class CurrentTask(Middleware):
|
|
||||||
_TASK: contextvars.ContextVar[list[Task] | None] = contextvars.ContextVar(
|
|
||||||
"_TASK",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_task(cls) -> Task:
|
|
||||||
task = cls._TASK.get()
|
|
||||||
if not task:
|
|
||||||
raise RuntimeError("CurrentTask.get_task() should only be called in a running task")
|
|
||||||
return task[-1]
|
|
||||||
|
|
||||||
def before_process_message(self, broker: Broker, message: Message):
|
|
||||||
tasks = self._TASK.get()
|
|
||||||
if tasks is None:
|
|
||||||
tasks = []
|
|
||||||
tasks.append(Task.objects.get(message_id=message.message_id))
|
|
||||||
self._TASK.set(tasks)
|
|
||||||
|
|
||||||
def after_process_message(
|
|
||||||
self,
|
|
||||||
broker: Broker,
|
|
||||||
message: Message,
|
|
||||||
*,
|
|
||||||
result: Any | None = None,
|
|
||||||
exception: Exception | None = None,
|
|
||||||
):
|
|
||||||
tasks: list[Task] | None = self._TASK.get()
|
|
||||||
if tasks is None or len(tasks) == 0:
|
|
||||||
LOGGER.warn("Task was None, not saving. This should not happen")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
tasks[-1].save()
|
|
||||||
self._TASK.set(tasks[:-1])
|
|
||||||
|
|
||||||
|
|
||||||
class TenantMiddleware(Middleware):
|
|
||||||
def before_process_message(self, broker: Broker, message: Message):
|
|
||||||
Task.objects.select_related("tenant").get(message_id=message.message_id).tenant.activate()
|
|
||||||
|
|
||||||
def after_process_message(self, *args, **kwargs):
|
|
||||||
Tenant.deactivate()
|
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from django.utils.module_loading import import_string
|
|||||||
from dramatiq.broker import Broker, Consumer, MessageProxy
|
from dramatiq.broker import Broker, Consumer, MessageProxy
|
||||||
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
from dramatiq.common import compute_backoff, current_millis, dq_name, xq_name
|
||||||
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
from dramatiq.errors import ConnectionError, QueueJoinTimeout
|
||||||
|
from dramatiq.logging import get_logger
|
||||||
from dramatiq.message import Message
|
from dramatiq.message import Message
|
||||||
from dramatiq.middleware import (
|
from dramatiq.middleware import (
|
||||||
Middleware,
|
Middleware,
|
||||||
@ -30,12 +31,11 @@ from dramatiq.middleware import (
|
|||||||
from pglock.core import _cast_lock_id
|
from pglock.core import _cast_lock_id
|
||||||
from psycopg import Notify, sql
|
from psycopg import Notify, sql
|
||||||
from psycopg.errors import AdminShutdown
|
from psycopg.errors import AdminShutdown
|
||||||
from structlog.stdlib import get_logger
|
|
||||||
|
|
||||||
from django_dramatiq_postgres.conf import Conf
|
from django_dramatiq_postgres.conf import Conf
|
||||||
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
|
from django_dramatiq_postgres.models import CHANNEL_PREFIX, ChannelIdentifier, TaskBase, TaskState
|
||||||
|
|
||||||
LOGGER = get_logger()
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
def channel_name(queue_name: str, identifier: ChannelIdentifier) -> str:
|
||||||
@ -62,7 +62,7 @@ class PostgresBroker(Broker):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, middleware=[], **kwargs)
|
super().__init__(*args, middleware=[], **kwargs)
|
||||||
self.logger = get_logger().bind()
|
self.logger = get_logger(__name__, type(self))
|
||||||
|
|
||||||
self.queues = set()
|
self.queues = set()
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ class PostgresBroker(Broker):
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
wait=tenacity.wait_random_exponential(multiplier=1, max=30),
|
wait=tenacity.wait_random_exponential(multiplier=1, max=30),
|
||||||
stop=tenacity.stop_after_attempt(10),
|
stop=tenacity.stop_after_attempt(10),
|
||||||
before_sleep=tenacity.before_sleep_log(LOGGER, logging.INFO, exc_info=True),
|
before_sleep=tenacity.before_sleep_log(logger, logging.INFO, exc_info=True),
|
||||||
)
|
)
|
||||||
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
|
def enqueue(self, message: Message, *, delay: int | None = None) -> Message:
|
||||||
canonical_queue_name = message.queue_name
|
canonical_queue_name = message.queue_name
|
||||||
@ -148,20 +148,26 @@ class PostgresBroker(Broker):
|
|||||||
|
|
||||||
self.declare_queue(canonical_queue_name)
|
self.declare_queue(canonical_queue_name)
|
||||||
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
self.logger.debug(f"Enqueueing message {message.message_id} on queue {queue_name}")
|
||||||
|
|
||||||
|
message.options["model_defaults"] = self.model_defaults(message)
|
||||||
self.emit_before("enqueue", message, delay)
|
self.emit_before("enqueue", message, delay)
|
||||||
|
|
||||||
query = {
|
query = {
|
||||||
"message_id": message.message_id,
|
"message_id": message.message_id,
|
||||||
}
|
}
|
||||||
defaults = self.model_defaults(message)
|
defaults = message.options["model_defaults"]
|
||||||
|
del message.options["model_defaults"]
|
||||||
create_defaults = {
|
create_defaults = {
|
||||||
**query,
|
**query,
|
||||||
**defaults,
|
**defaults,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.query_set.update_or_create(
|
self.query_set.update_or_create(
|
||||||
**query,
|
**query,
|
||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
create_defaults=create_defaults,
|
create_defaults=create_defaults,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.emit_after("enqueue", message, delay)
|
self.emit_after("enqueue", message, delay)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@ -209,7 +215,7 @@ class _PostgresConsumer(Consumer):
|
|||||||
timeout: int,
|
timeout: int,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.logger = get_logger().bind()
|
self.logger = get_logger(__name__, type(self))
|
||||||
|
|
||||||
self.notifies: list[Notify] = []
|
self.notifies: list[Notify] = []
|
||||||
self.broker = broker
|
self.broker = broker
|
||||||
|
|||||||
@ -1,10 +1,18 @@
|
|||||||
|
import contextvars
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from django.db import (
|
from django.db import (
|
||||||
close_old_connections,
|
close_old_connections,
|
||||||
connections,
|
connections,
|
||||||
)
|
)
|
||||||
|
from dramatiq.actor import Actor
|
||||||
|
from dramatiq.broker import Broker
|
||||||
|
from dramatiq.logging import get_logger
|
||||||
|
from dramatiq.message import Message
|
||||||
from dramatiq.middleware.middleware import Middleware
|
from dramatiq.middleware.middleware import Middleware
|
||||||
|
|
||||||
from django_dramatiq_postgres.conf import Conf
|
from django_dramatiq_postgres.conf import Conf
|
||||||
|
from django_dramatiq_postgres.models import TaskBase
|
||||||
|
|
||||||
|
|
||||||
class DbConnectionMiddleware(Middleware):
|
class DbConnectionMiddleware(Middleware):
|
||||||
@ -22,3 +30,49 @@ class DbConnectionMiddleware(Middleware):
|
|||||||
before_consumer_thread_shutdown = _close_connections
|
before_consumer_thread_shutdown = _close_connections
|
||||||
before_worker_thread_shutdown = _close_connections
|
before_worker_thread_shutdown = _close_connections
|
||||||
before_worker_shutdown = _close_connections
|
before_worker_shutdown = _close_connections
|
||||||
|
|
||||||
|
|
||||||
|
class FullyQualifiedActorName(Middleware):
|
||||||
|
def before_declare_actor(self, broker: Broker, actor: Actor):
|
||||||
|
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentTask(Middleware):
|
||||||
|
def __init__(self):
|
||||||
|
self.logger = get_logger(__name__, type(self))
|
||||||
|
|
||||||
|
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
|
||||||
|
_TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
|
||||||
|
"_TASKS",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_task(cls) -> TaskBase:
|
||||||
|
task = cls._TASKS.get()
|
||||||
|
if not task:
|
||||||
|
raise RuntimeError("CurrentTask.get_task() can only be called in a running task")
|
||||||
|
return task[-1]
|
||||||
|
|
||||||
|
def before_process_message(self, broker: Broker, message: Message):
|
||||||
|
tasks = self._TASKS.get()
|
||||||
|
if tasks is None:
|
||||||
|
tasks = []
|
||||||
|
tasks.append(message.options["task"])
|
||||||
|
self._TASKS.set(tasks)
|
||||||
|
|
||||||
|
def after_process_message(
|
||||||
|
self,
|
||||||
|
broker: Broker,
|
||||||
|
message: Message,
|
||||||
|
*,
|
||||||
|
result: Any | None = None,
|
||||||
|
exception: Exception | None = None,
|
||||||
|
):
|
||||||
|
tasks: list[TaskBase] | None = self._TASKS.get()
|
||||||
|
if tasks is None or len(tasks) == 0:
|
||||||
|
self.logger.warning("Task was None, not saving. This should not happen.")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
tasks[-1].save()
|
||||||
|
self._TASKS.set(tasks[:-1])
|
||||||
|
|||||||
Reference in New Issue
Block a user