Compare commits

..

5 Commits

Author SHA1 Message Date
cddc1c0478 WIP 2025-07-01 19:28:30 +03:00
eab6e288d7 core: bump lxml from 5.4.0 to 6.0.0 (#15281)
Bumps [lxml](https://github.com/lxml/lxml) from 5.4.0 to 6.0.0.
- [Release notes](https://github.com/lxml/lxml/releases)
- [Changelog](https://github.com/lxml/lxml/blob/master/CHANGES.txt)
- [Commits](https://github.com/lxml/lxml/compare/lxml-5.4.0...lxml-6.0.0)

---
updated-dependencies:
- dependency-name: lxml
  dependency-version: 6.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 17:19:01 +02:00
91c2863358 website: bump @types/node from 24.0.7 to 24.0.8 in /website (#15328)
Bumps [@types/node](https://github.com/DefinitelyTyped/DefinitelyTyped/tree/HEAD/types/node) from 24.0.7 to 24.0.8.
- [Release notes](https://github.com/DefinitelyTyped/DefinitelyTyped/releases)
- [Commits](https://github.com/DefinitelyTyped/DefinitelyTyped/commits/HEAD/types/node)

---
updated-dependencies:
- dependency-name: "@types/node"
  dependency-version: 24.0.8
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 17:01:24 +02:00
1638e95bc7 website: bump the build group in /website with 3 updates (#15279)
Bumps the build group in /website with 3 updates: [@rspack/binding-darwin-arm64](https://github.com/web-infra-dev/rspack/tree/HEAD/packages/rspack), [@rspack/binding-linux-arm64-gnu](https://github.com/web-infra-dev/rspack/tree/HEAD/packages/rspack) and [@rspack/binding-linux-x64-gnu](https://github.com/web-infra-dev/rspack/tree/HEAD/packages/rspack).


Updates `@rspack/binding-darwin-arm64` from 1.3.15 to 1.4.0
- [Release notes](https://github.com/web-infra-dev/rspack/releases)
- [Commits](https://github.com/web-infra-dev/rspack/commits/v1.4.0/packages/rspack)

Updates `@rspack/binding-linux-arm64-gnu` from 1.3.15 to 1.4.0
- [Release notes](https://github.com/web-infra-dev/rspack/releases)
- [Commits](https://github.com/web-infra-dev/rspack/commits/v1.4.0/packages/rspack)

Updates `@rspack/binding-linux-x64-gnu` from 1.3.15 to 1.4.0
- [Release notes](https://github.com/web-infra-dev/rspack/releases)
- [Commits](https://github.com/web-infra-dev/rspack/commits/v1.4.0/packages/rspack)

---
updated-dependencies:
- dependency-name: "@rspack/binding-darwin-arm64"
  dependency-version: 1.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: build
- dependency-name: "@rspack/binding-linux-arm64-gnu"
  dependency-version: 1.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: build
- dependency-name: "@rspack/binding-linux-x64-gnu"
  dependency-version: 1.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: build
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 17:00:46 +02:00
8f75131541 website: bump the eslint group in /website with 3 updates (#15329)
Bumps the eslint group in /website with 3 updates: [@typescript-eslint/eslint-plugin](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/eslint-plugin), [@typescript-eslint/parser](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/parser) and [typescript-eslint](https://github.com/typescript-eslint/typescript-eslint/tree/HEAD/packages/typescript-eslint).


Updates `@typescript-eslint/eslint-plugin` from 8.35.0 to 8.35.1
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/eslint-plugin/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.35.1/packages/eslint-plugin)

Updates `@typescript-eslint/parser` from 8.35.0 to 8.35.1
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/parser/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.35.1/packages/parser)

Updates `typescript-eslint` from 8.35.0 to 8.35.1
- [Release notes](https://github.com/typescript-eslint/typescript-eslint/releases)
- [Changelog](https://github.com/typescript-eslint/typescript-eslint/blob/main/packages/typescript-eslint/CHANGELOG.md)
- [Commits](https://github.com/typescript-eslint/typescript-eslint/commits/v8.35.1/packages/typescript-eslint)

---
updated-dependencies:
- dependency-name: "@typescript-eslint/eslint-plugin"
  dependency-version: 8.35.1
  dependency-type: direct:development
  update-type: version-update:semver-patch
  dependency-group: eslint
- dependency-name: "@typescript-eslint/parser"
  dependency-version: 8.35.1
  dependency-type: direct:development
  update-type: version-update:semver-patch
  dependency-group: eslint
- dependency-name: typescript-eslint
  dependency-version: 8.35.1
  dependency-type: direct:development
  update-type: version-update:semver-patch
  dependency-group: eslint
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-07-01 17:00:00 +02:00
232 changed files with 3918 additions and 6696 deletions

5
.gitignore vendored
View File

@ -100,6 +100,9 @@ ipython_config.py
# pyenv # pyenv
.python-version .python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files # SageMath parsed files
*.sage.py *.sage.py
@ -163,6 +166,8 @@ dmypy.json
# pyenv # pyenv
# celery beat schedule file
# SageMath parsed files # SageMath parsed files
# Environments # Environments

View File

@ -122,7 +122,6 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \ RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
--mount=type=bind,target=uv.lock,src=uv.lock \ --mount=type=bind,target=uv.lock,src=uv.lock \
--mount=type=bind,target=packages,src=packages \
--mount=type=cache,target=/root/.cache/uv \ --mount=type=cache,target=/root/.cache/uv \
uv sync --frozen --no-install-project --no-dev uv sync --frozen --no-install-project --no-dev
@ -168,7 +167,6 @@ COPY ./blueprints /blueprints
COPY ./lifecycle/ /lifecycle COPY ./lifecycle/ /lifecycle
COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
COPY --from=go-builder /go/authentik /bin/authentik COPY --from=go-builder /go/authentik /bin/authentik
COPY ./packages/ /ak-root/packages
COPY --from=python-deps /ak-root/.venv /ak-root/.venv COPY --from=python-deps /ak-root/.venv /ak-root/.venv
COPY --from=node-builder /work/web/dist/ /web/dist/ COPY --from=node-builder /work/web/dist/ /web/dist/
COPY --from=node-builder /work/web/authentik/ /web/authentik/ COPY --from=node-builder /work/web/authentik/ /web/authentik/

View File

@ -6,7 +6,7 @@ PWD = $(shell pwd)
UID = $(shell id -u) UID = $(shell id -u)
GID = $(shell id -g) GID = $(shell id -g)
NPM_VERSION = $(shell python -m scripts.generate_semver) NPM_VERSION = $(shell python -m scripts.generate_semver)
PY_SOURCES = authentik packages tests scripts lifecycle .github PY_SOURCES = authentik tests scripts lifecycle .github
DOCKER_IMAGE ?= "authentik:test" DOCKER_IMAGE ?= "authentik:test"
GEN_API_TS = gen-ts-api GEN_API_TS = gen-ts-api

View File

@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer):
return __version__ return __version__
version_in_cache = cache.get(VERSION_CACHE_KEY) version_in_cache = cache.get(VERSION_CACHE_KEY)
if not version_in_cache: # pragma: no cover if not version_in_cache: # pragma: no cover
update_latest_version.send() update_latest_version.delay()
return __version__ return __version__
return version_in_cache return version_in_cache

View File

@ -0,0 +1,57 @@
"""authentik administration overview"""
from socket import gethostname
from django.conf import settings
from drf_spectacular.utils import extend_schema, inline_serializer
from packaging.version import parse
from rest_framework.fields import BooleanField, CharField
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from authentik import get_full_version
from authentik.rbac.permissions import HasPermission
from authentik.root.celery import CELERY_APP
class WorkerView(APIView):
"""Get currently connected worker count."""
permission_classes = [HasPermission("authentik_rbac.view_system_info")]
@extend_schema(
responses=inline_serializer(
"Worker",
fields={
"worker_id": CharField(),
"version": CharField(),
"version_matching": BooleanField(),
},
many=True,
)
)
def get(self, request: Request) -> Response:
"""Get currently connected worker count."""
raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
our_version = parse(get_full_version())
response = []
for worker in raw:
key = list(worker.keys())[0]
version = worker[key].get("version")
version_matching = False
if version:
version_matching = parse(version) == our_version
response.append(
{"worker_id": key, "version": version, "version_matching": version_matching}
)
# In debug we run with `task_always_eager`, so tasks are ran on the main process
if settings.DEBUG: # pragma: no cover
response.append(
{
"worker_id": f"authentik-debug@{gethostname()}",
"version": get_full_version(),
"version_matching": True,
}
)
return Response(response)

View File

@ -3,9 +3,6 @@
from prometheus_client import Info from prometheus_client import Info
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
PROM_INFO = Info("authentik_version", "Currently running authentik version") PROM_INFO = Info("authentik_version", "Currently running authentik version")
@ -33,15 +30,3 @@ class AuthentikAdminConfig(ManagedAppConfig):
notification_version = notification.event.context["new_version"] notification_version = notification.event.context["new_version"]
if LOCAL_VERSION >= parse(notification_version): if LOCAL_VERSION >= parse(notification_version):
notification.delete() notification.delete()
@property
def global_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.admin.tasks import update_latest_version
return [
ScheduleSpec(
actor=update_latest_version,
crontab=f"{fqdn_rand('admin_latest_version')} * * * *",
paused=CONFIG.get_bool("disable_update_check"),
),
]

View File

@ -0,0 +1,15 @@
"""authentik admin settings"""
from celery.schedules import crontab
from django_tenants.utils import get_public_schema_name
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"admin_latest_version": {
"task": "authentik.admin.tasks.update_latest_version",
"schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
"tenant_schemas": [get_public_schema_name()],
"options": {"queue": "authentik_scheduled"},
}
}

View File

@ -0,0 +1,35 @@
"""admin signals"""
from django.dispatch import receiver
from packaging.version import parse
from prometheus_client import Gauge
from authentik import get_full_version
from authentik.root.celery import CELERY_APP
from authentik.root.monitoring import monitoring_set
GAUGE_WORKERS = Gauge(
"authentik_admin_workers",
"Currently connected workers, their versions and if they are the same version as authentik",
["version", "version_matched"],
)
_version = parse(get_full_version())
@receiver(monitoring_set)
def monitoring_set_workers(sender, **kwargs):
"""Set worker gauge"""
raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
worker_version_count = {}
for worker in raw:
key = list(worker.keys())[0]
version = worker[key].get("version")
version_matching = False
if version:
version_matching = parse(version) == _version
worker_version_count.setdefault(version, {"count": 0, "matching": version_matching})
worker_version_count[version]["count"] += 1
for version, stats in worker_version_count.items():
GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"])

View File

@ -2,8 +2,6 @@
from django.core.cache import cache from django.core.cache import cache
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq import actor
from packaging.version import parse from packaging.version import parse
from requests import RequestException from requests import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -11,9 +9,10 @@ from structlog.stdlib import get_logger
from authentik import __version__, get_build_hash from authentik import __version__, get_build_hash
from authentik.admin.apps import PROM_INFO from authentik.admin.apps import PROM_INFO
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
VERSION_NULL = "0.0.0" VERSION_NULL = "0.0.0"
@ -33,12 +32,13 @@ def _set_prom_info():
) )
@actor(description=_("Update latest version info.")) @CELERY_APP.task(bind=True, base=SystemTask)
def update_latest_version(): @prefill_task
self: Task = CurrentTask.get_task() def update_latest_version(self: SystemTask):
"""Update latest version info"""
if CONFIG.get_bool("disable_update_check"): if CONFIG.get_bool("disable_update_check"):
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
self.info("Version check disabled.") self.set_status(TaskStatus.WARNING, "Version check disabled.")
return return
try: try:
response = get_http_session().get( response = get_http_session().get(
@ -48,7 +48,7 @@ def update_latest_version():
data = response.json() data = response.json()
upstream_version = data.get("stable", {}).get("version") upstream_version = data.get("stable", {}).get("version")
cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
self.info("Successfully updated latest Version") self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
_set_prom_info() _set_prom_info()
# Check if upstream version is newer than what we're running, # Check if upstream version is newer than what we're running,
# and if no event exists yet, create one. # and if no event exists yet, create one.
@ -71,7 +71,7 @@ def update_latest_version():
).save() ).save()
except (RequestException, IndexError) as exc: except (RequestException, IndexError) as exc:
cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
raise exc self.set_error(exc)
_set_prom_info() _set_prom_info()

View File

@ -29,6 +29,13 @@ class TestAdminAPI(TestCase):
body = loads(response.content) body = loads(response.content)
self.assertEqual(body["version_current"], __version__) self.assertEqual(body["version_current"], __version__)
def test_workers(self):
"""Test Workers API"""
response = self.client.get(reverse("authentik_api:admin_workers"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(len(body), 0)
def test_apps(self): def test_apps(self):
"""Test apps API""" """Test apps API"""
response = self.client.get(reverse("authentik_api:apps-list")) response = self.client.get(reverse("authentik_api:apps-list"))

View File

@ -30,7 +30,7 @@ class TestAdminTasks(TestCase):
"""Test Update checker with valid response""" """Test Update checker with valid response"""
with Mocker() as mocker, CONFIG.patch("disable_update_check", False): with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID) mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
update_latest_version.send() update_latest_version.delay().get()
self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
self.assertTrue( self.assertTrue(
Event.objects.filter( Event.objects.filter(
@ -40,7 +40,7 @@ class TestAdminTasks(TestCase):
).exists() ).exists()
) )
# test that a consecutive check doesn't create a duplicate event # test that a consecutive check doesn't create a duplicate event
update_latest_version.send() update_latest_version.delay().get()
self.assertEqual( self.assertEqual(
len( len(
Event.objects.filter( Event.objects.filter(
@ -56,7 +56,7 @@ class TestAdminTasks(TestCase):
"""Test Update checker with invalid response""" """Test Update checker with invalid response"""
with Mocker() as mocker: with Mocker() as mocker:
mocker.get("https://version.goauthentik.io/version.json", status_code=400) mocker.get("https://version.goauthentik.io/version.json", status_code=400)
update_latest_version.send() update_latest_version.delay().get()
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
self.assertFalse( self.assertFalse(
Event.objects.filter( Event.objects.filter(
@ -67,15 +67,14 @@ class TestAdminTasks(TestCase):
def test_version_disabled(self): def test_version_disabled(self):
"""Test Update checker while its disabled""" """Test Update checker while its disabled"""
with CONFIG.patch("disable_update_check", True): with CONFIG.patch("disable_update_check", True):
update_latest_version.send() update_latest_version.delay().get()
self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
def test_clear_update_notifications(self): def test_clear_update_notifications(self):
"""Test clear of previous notification""" """Test clear of previous notification"""
admin_config = apps.get_app_config("authentik_admin") admin_config = apps.get_app_config("authentik_admin")
Event.objects.create( Event.objects.create(
action=EventAction.UPDATE_AVAILABLE, action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
context={"new_version": "99999999.9999999.9999999"},
) )
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={}) Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})

View File

@ -6,11 +6,13 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
from authentik.admin.api.system import SystemView from authentik.admin.api.system import SystemView
from authentik.admin.api.version import VersionView from authentik.admin.api.version import VersionView
from authentik.admin.api.version_history import VersionHistoryViewSet from authentik.admin.api.version_history import VersionHistoryViewSet
from authentik.admin.api.workers import WorkerView
api_urlpatterns = [ api_urlpatterns = [
("admin/apps", AppsViewSet, "apps"), ("admin/apps", AppsViewSet, "apps"),
("admin/models", ModelViewSet, "models"), ("admin/models", ModelViewSet, "models"),
path("admin/version/", VersionView.as_view(), name="admin_version"), path("admin/version/", VersionView.as_view(), name="admin_version"),
("admin/version/history", VersionHistoryViewSet, "version_history"), ("admin/version/history", VersionHistoryViewSet, "version_history"),
path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
path("admin/system/", SystemView.as_view(), name="admin_system"), path("admin/system/", SystemView.as_view(), name="admin_system"),
] ]

View File

@ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
"""Ensure the path (if set) specified is retrievable""" """Ensure the path (if set) specified is retrievable"""
if path == "" or path.startswith(OCI_PREFIX): if path == "" or path.startswith(OCI_PREFIX):
return path return path
files: list[dict] = blueprints_find_dict.send().get_result(block=True) files: list[dict] = blueprints_find_dict.delay().get()
if path not in [file["path"] for file in files]: if path not in [file["path"] for file in files]:
raise ValidationError(_("Blueprint file does not exist")) raise ValidationError(_("Blueprint file does not exist"))
return path return path
@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
@action(detail=False, pagination_class=None, filter_backends=[]) @action(detail=False, pagination_class=None, filter_backends=[])
def available(self, request: Request) -> Response: def available(self, request: Request) -> Response:
"""Get blueprints""" """Get blueprints"""
files: list[dict] = blueprints_find_dict.send().get_result(block=True) files: list[dict] = blueprints_find_dict.delay().get()
return Response(files) return Response(files)
@permission_required("authentik_blueprints.view_blueprintinstance") @permission_required("authentik_blueprints.view_blueprintinstance")
@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
def apply(self, request: Request, *args, **kwargs) -> Response: def apply(self, request: Request, *args, **kwargs) -> Response:
"""Apply a blueprint""" """Apply a blueprint"""
blueprint = self.get_object() blueprint = self.get_object()
apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint) apply_blueprint.delay(str(blueprint.pk)).get()
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)

View File

@ -6,12 +6,9 @@ from inspect import ismethod
from django.apps import AppConfig from django.apps import AppConfig
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from dramatiq.broker import get_broker
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.lib.utils.time import fqdn_rand
from authentik.root.signals import startup from authentik.root.signals import startup
from authentik.tasks.schedules.lib import ScheduleSpec
class ManagedAppConfig(AppConfig): class ManagedAppConfig(AppConfig):
@ -37,7 +34,7 @@ class ManagedAppConfig(AppConfig):
def import_related(self): def import_related(self):
"""Automatically import related modules which rely on just being imported """Automatically import related modules which rely on just being imported
to register themselves (mainly django signals and tasks)""" to register themselves (mainly django signals and celery tasks)"""
def import_relative(rel_module: str): def import_relative(rel_module: str):
try: try:
@ -83,16 +80,6 @@ class ManagedAppConfig(AppConfig):
func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY
return func return func
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
"""Get a list of schedule specs that must exist in each tenant"""
return []
@property
def global_schedule_specs(self) -> list[ScheduleSpec]:
"""Get a list of schedule specs that must exist in the default tenant"""
return []
def _reconcile_tenant(self) -> None: def _reconcile_tenant(self) -> None:
"""reconcile ourselves for tenanted methods""" """reconcile ourselves for tenanted methods"""
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -113,12 +100,8 @@ class ManagedAppConfig(AppConfig):
""" """
from django_tenants.utils import get_public_schema_name, schema_context from django_tenants.utils import get_public_schema_name, schema_context
try: with schema_context(get_public_schema_name()):
with schema_context(get_public_schema_name()): self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
except (DatabaseError, ProgrammingError, InternalError) as exc:
self.logger.debug("Failed to access database to run reconcile", exc=exc)
return
class AuthentikBlueprintsConfig(ManagedAppConfig): class AuthentikBlueprintsConfig(ManagedAppConfig):
@ -129,29 +112,19 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
verbose_name = "authentik Blueprints" verbose_name = "authentik Blueprints"
default = True default = True
@ManagedAppConfig.reconcile_global
def load_blueprints_v1_tasks(self):
"""Load v1 tasks"""
self.import_module("authentik.blueprints.v1.tasks")
@ManagedAppConfig.reconcile_tenant
def blueprints_discovery(self):
"""Run blueprint discovery"""
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
blueprints_discovery.delay()
clear_failed_blueprints.delay()
def import_models(self): def import_models(self):
super().import_models() super().import_models()
self.import_module("authentik.blueprints.v1.meta.apply_blueprint") self.import_module("authentik.blueprints.v1.meta.apply_blueprint")
@ManagedAppConfig.reconcile_global
def tasks_middlewares(self):
from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware
get_broker().add_middleware(BlueprintWatcherMiddleware())
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
return [
ScheduleSpec(
actor=blueprints_discovery,
crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *",
send_on_startup=True,
),
ScheduleSpec(
actor=clear_failed_blueprints,
crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *",
send_on_startup=True,
),
]

View File

@ -3,7 +3,6 @@
from pathlib import Path from pathlib import Path
from uuid import uuid4 from uuid import uuid4
from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -72,13 +71,6 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
enabled = models.BooleanField(default=True) enabled = models.BooleanField(default=True)
managed_models = ArrayField(models.TextField(), default=list) managed_models = ArrayField(models.TextField(), default=list)
# Manual link to tasks instead of using TasksModel because of loop imports
tasks = GenericRelation(
"authentik_tasks.Task",
content_type_field="rel_obj_content_type",
object_id_field="rel_obj_id",
)
class Meta: class Meta:
verbose_name = _("Blueprint Instance") verbose_name = _("Blueprint Instance")
verbose_name_plural = _("Blueprint Instances") verbose_name_plural = _("Blueprint Instances")

View File

@ -0,0 +1,18 @@
"""blueprint Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"blueprints_v1_discover": {
"task": "authentik.blueprints.v1.tasks.blueprints_discovery",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
"blueprints_v1_cleanup": {
"task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,2 +0,0 @@
# Import all v1 tasks for auto task discovery
from authentik.blueprints.v1.tasks import * # noqa: F403

View File

@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
file.seek(0) file.seek(0)
file_hash = sha512(file.read().encode()).hexdigest() file_hash = sha512(file.read().encode()).hexdigest()
file.flush() file.flush()
blueprints_discovery.send() blueprints_discovery()
instance = BlueprintInstance.objects.filter(name=blueprint_id).first() instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
self.assertEqual(instance.last_applied_hash, file_hash) self.assertEqual(instance.last_applied_hash, file_hash)
self.assertEqual( self.assertEqual(
@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
) )
) )
file.flush() file.flush()
blueprints_discovery.send() blueprints_discovery()
blueprint = BlueprintInstance.objects.filter(name="foo").first() blueprint = BlueprintInstance.objects.filter(name="foo").first()
self.assertEqual( self.assertEqual(
blueprint.last_applied_hash, blueprint.last_applied_hash,
@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
) )
) )
file.flush() file.flush()
blueprints_discovery.send() blueprints_discovery()
blueprint.refresh_from_db() blueprint.refresh_from_db()
self.assertEqual( self.assertEqual(
blueprint.last_applied_hash, blueprint.last_applied_hash,

View File

@ -57,6 +57,7 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
EndpointDeviceConnection, EndpointDeviceConnection,
) )
from authentik.events.logs import LogEvent, capture_logs from authentik.events.logs import LogEvent, capture_logs
from authentik.events.models import SystemTask
from authentik.events.utils import cleanse_dict from authentik.events.utils import cleanse_dict
from authentik.flows.models import FlowToken, Stage from authentik.flows.models import FlowToken, Stage
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
@ -76,7 +77,6 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
from authentik.rbac.models import Role from authentik.rbac.models import Role
from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
from authentik.tasks.models import Task
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
# Context set when the serializer is created in a blueprint context # Context set when the serializer is created in a blueprint context
@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]:
SCIMProviderGroup, SCIMProviderGroup,
SCIMProviderUser, SCIMProviderUser,
Tenant, Tenant,
Task, SystemTask,
ConnectionToken, ConnectionToken,
AuthorizationCode, AuthorizationCode,
AccessToken, AccessToken,

View File

@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
return MetaResult() return MetaResult()
LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
apply_blueprint(self.blueprint_instance.pk) apply_blueprint(str(self.blueprint_instance.pk))
return MetaResult() return MetaResult()

View File

@ -4,17 +4,12 @@ from dataclasses import asdict, dataclass, field
from hashlib import sha512 from hashlib import sha512
from pathlib import Path from pathlib import Path
from sys import platform from sys import platform
from uuid import UUID
from dacite.core import from_dict from dacite.core import from_dict
from django.conf import settings
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
from dramatiq.actor import actor
from dramatiq.middleware import Middleware
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from watchdog.events import ( from watchdog.events import (
FileCreatedEvent, FileCreatedEvent,
@ -36,13 +31,15 @@ from authentik.blueprints.v1.importer import Importer
from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
from authentik.blueprints.v1.oci import OCI_PREFIX from authentik.blueprints.v1.oci import OCI_PREFIX
from authentik.events.logs import capture_logs from authentik.events.logs import capture_logs
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.events.utils import sanitize_dict from authentik.events.utils import sanitize_dict
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
from authentik.tasks.schedules.models import Schedule
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
LOGGER = get_logger() LOGGER = get_logger()
_file_watcher_started = False
@dataclass @dataclass
@ -56,21 +53,22 @@ class BlueprintFile:
meta: BlueprintMetadata | None = field(default=None) meta: BlueprintMetadata | None = field(default=None)
class BlueprintWatcherMiddleware(Middleware): def start_blueprint_watcher():
def start_blueprint_watcher(self): """Start blueprint watcher, if it's not running already."""
"""Start blueprint watcher""" # This function might be called twice since it's called on celery startup
observer = Observer()
kwargs = {}
if platform.startswith("linux"):
kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
observer.schedule(
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
)
observer.start()
def after_worker_boot(self, broker, worker): global _file_watcher_started # noqa: PLW0603
if not settings.TEST: if _file_watcher_started:
self.start_blueprint_watcher() return
observer = Observer()
kwargs = {}
if platform.startswith("linux"):
kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
observer.schedule(
BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
)
observer.start()
_file_watcher_started = True
class BlueprintEventHandler(FileSystemEventHandler): class BlueprintEventHandler(FileSystemEventHandler):
@ -94,7 +92,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
LOGGER.debug("new blueprint file created, starting discovery") LOGGER.debug("new blueprint file created, starting discovery")
for tenant in Tenant.objects.filter(ready=True): for tenant in Tenant.objects.filter(ready=True):
with tenant: with tenant:
Schedule.dispatch_by_actor(blueprints_discovery) blueprints_discovery.delay()
def on_modified(self, event: FileSystemEvent): def on_modified(self, event: FileSystemEvent):
"""Process file modification""" """Process file modification"""
@ -105,14 +103,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
with tenant: with tenant:
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
LOGGER.debug("modified blueprint file, starting apply", instance=instance) LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) apply_blueprint.delay(instance.pk.hex)
@actor( @CELERY_APP.task(
description=_("Find blueprints as `blueprints_find` does, but return a safe dict."),
throws=(DatabaseError, ProgrammingError, InternalError), throws=(DatabaseError, ProgrammingError, InternalError),
) )
def blueprints_find_dict(): def blueprints_find_dict():
"""Find blueprints as `blueprints_find` does, but return a safe dict"""
blueprints = [] blueprints = []
for blueprint in blueprints_find(): for blueprint in blueprints_find():
blueprints.append(sanitize_dict(asdict(blueprint))) blueprints.append(sanitize_dict(asdict(blueprint)))
@ -148,19 +146,21 @@ def blueprints_find() -> list[BlueprintFile]:
return blueprints return blueprints
@actor( @CELERY_APP.task(
description=_("Find blueprints and check if they need to be created in the database."), throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
throws=(DatabaseError, ProgrammingError, InternalError),
) )
def blueprints_discovery(path: str | None = None): @prefill_task
self: Task = CurrentTask.get_task() def blueprints_discovery(self: SystemTask, path: str | None = None):
"""Find blueprints and check if they need to be created in the database"""
count = 0 count = 0
for blueprint in blueprints_find(): for blueprint in blueprints_find():
if path and blueprint.path != path: if path and blueprint.path != path:
continue continue
check_blueprint_v1_file(blueprint) check_blueprint_v1_file(blueprint)
count += 1 count += 1
self.info(f"Successfully imported {count} files.") self.set_status(
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
)
def check_blueprint_v1_file(blueprint: BlueprintFile): def check_blueprint_v1_file(blueprint: BlueprintFile):
@ -187,26 +187,22 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
) )
if instance.last_applied_hash != blueprint.hash: if instance.last_applied_hash != blueprint.hash:
LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path) LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) apply_blueprint.delay(str(instance.pk))
@actor(description=_("Apply single blueprint.")) @CELERY_APP.task(
def apply_blueprint(instance_pk: UUID): bind=True,
try: base=SystemTask,
self: Task = CurrentTask.get_task() )
except CurrentTaskNotFound: def apply_blueprint(self: SystemTask, instance_pk: str):
self = Task() """Apply single blueprint"""
self.set_uid(str(instance_pk)) self.save_on_success = False
instance: BlueprintInstance | None = None instance: BlueprintInstance | None = None
try: try:
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
if not instance: if not instance or not instance.enabled:
self.warning(f"Could not find blueprint {instance_pk}, skipping")
return return
self.set_uid(slugify(instance.name)) self.set_uid(slugify(instance.name))
if not instance.enabled:
self.info(f"Blueprint {instance.name} is disabled, skipping")
return
blueprint_content = instance.retrieve() blueprint_content = instance.retrieve()
file_hash = sha512(blueprint_content.encode()).hexdigest() file_hash = sha512(blueprint_content.encode()).hexdigest()
importer = Importer.from_string(blueprint_content, instance.context) importer = Importer.from_string(blueprint_content, instance.context)
@ -216,18 +212,19 @@ def apply_blueprint(instance_pk: UUID):
if not valid: if not valid:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
instance.save() instance.save()
self.logs(logs) self.set_status(TaskStatus.ERROR, *logs)
return return
with capture_logs() as logs: with capture_logs() as logs:
applied = importer.apply() applied = importer.apply()
if not applied: if not applied:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
instance.save() instance.save()
self.logs(logs) self.set_status(TaskStatus.ERROR, *logs)
return return
instance.status = BlueprintInstanceStatus.SUCCESSFUL instance.status = BlueprintInstanceStatus.SUCCESSFUL
instance.last_applied_hash = file_hash instance.last_applied_hash = file_hash
instance.last_applied = now() instance.last_applied = now()
self.set_status(TaskStatus.SUCCESSFUL)
except ( except (
OSError, OSError,
DatabaseError, DatabaseError,
@ -238,14 +235,15 @@ def apply_blueprint(instance_pk: UUID):
) as exc: ) as exc:
if instance: if instance:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
self.error(exc) self.set_error(exc)
finally: finally:
if instance: if instance:
instance.save() instance.save()
@actor(description=_("Remove blueprints which couldn't be fetched.")) @CELERY_APP.task()
def clear_failed_blueprints(): def clear_failed_blueprints():
"""Remove blueprints which couldn't be fetched"""
# Exclude OCI blueprints as those might be temporarily unavailable # Exclude OCI blueprints as those might be temporarily unavailable
for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX): for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX):
try: try:

View File

@ -9,7 +9,6 @@ class AuthentikBrandsConfig(ManagedAppConfig):
name = "authentik.brands" name = "authentik.brands"
label = "authentik_brands" label = "authentik_brands"
verbose_name = "authentik Brands" verbose_name = "authentik Brands"
default = True
mountpoints = { mountpoints = {
"authentik.brands.urls_root": "", "authentik.brands.urls_root": "",
} }

View File

@ -1,7 +1,8 @@
"""authentik core app config""" """authentik core app config"""
from django.conf import settings
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.tasks.schedules.lib import ScheduleSpec
class AuthentikCoreConfig(ManagedAppConfig): class AuthentikCoreConfig(ManagedAppConfig):
@ -13,6 +14,14 @@ class AuthentikCoreConfig(ManagedAppConfig):
mountpoint = "" mountpoint = ""
default = True default = True
@ManagedAppConfig.reconcile_global
def debug_worker_hook(self):
"""Dispatch startup tasks inline when debugging"""
if settings.DEBUG:
from authentik.root.celery import worker_ready_hook
worker_ready_hook()
@ManagedAppConfig.reconcile_tenant @ManagedAppConfig.reconcile_tenant
def source_inbuilt(self): def source_inbuilt(self):
"""Reconcile inbuilt source""" """Reconcile inbuilt source"""
@ -25,18 +34,3 @@ class AuthentikCoreConfig(ManagedAppConfig):
}, },
managed=Source.MANAGED_INBUILT, managed=Source.MANAGED_INBUILT,
) )
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.core.tasks import clean_expired_models, clean_temporary_users
return [
ScheduleSpec(
actor=clean_expired_models,
crontab="2-59/5 * * * *",
),
ScheduleSpec(
actor=clean_temporary_users,
crontab="9-59/5 * * * *",
),
]

View File

@ -0,0 +1,21 @@
"""Run bootstrap tasks"""
from django.core.management.base import BaseCommand
from django_tenants.utils import get_public_schema_name
from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
from authentik.tenants.models import Tenant
class Command(BaseCommand):
"""Run bootstrap tasks to ensure certain objects are created"""
def handle(self, **options):
for task in _get_startup_tasks_default_tenant():
with Tenant.objects.get(schema_name=get_public_schema_name()):
task()
for task in _get_startup_tasks_all_tenants():
for tenant in Tenant.objects.filter(ready=True):
with tenant:
task()

View File

@ -0,0 +1,47 @@
"""Run worker"""
from sys import exit as sysexit
from tempfile import tempdir
from celery.apps.worker import Worker
from django.core.management.base import BaseCommand
from django.db import close_old_connections
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
from authentik.lib.debug import start_debug_server
from authentik.root.celery import CELERY_APP
LOGGER = get_logger()
class Command(BaseCommand):
"""Run worker"""
def add_arguments(self, parser):
parser.add_argument(
"-b",
"--beat",
action="store_false",
help="When set, this worker will _not_ run Beat (scheduled) tasks",
)
def handle(self, **options):
LOGGER.debug("Celery options", **options)
close_old_connections()
start_debug_server()
worker: Worker = CELERY_APP.Worker(
no_color=False,
quiet=True,
optimization="fair",
autoscale=(CONFIG.get_int("worker.concurrency"), 1),
task_events=True,
beat=options.get("beat", True),
schedule_filename=f"{tempdir}/celerybeat-schedule",
queues=["authentik", "authentik_scheduled", "authentik_events"],
)
for task in CELERY_APP.tasks:
LOGGER.debug("Registered task", task=task)
worker.start()
sysexit(worker.exitcode)

View File

@ -3,9 +3,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import ( from authentik.core.models import (
@ -14,14 +11,17 @@ from authentik.core.models import (
ExpiringModel, ExpiringModel,
User, User,
) )
from authentik.tasks.models import Task from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
@actor(description=_("Remove expired objects.")) @CELERY_APP.task(bind=True, base=SystemTask)
def clean_expired_models(): @prefill_task
self: Task = CurrentTask.get_task() def clean_expired_models(self: SystemTask):
"""Remove expired objects"""
messages = []
for cls in ExpiringModel.__subclasses__(): for cls in ExpiringModel.__subclasses__():
cls: ExpiringModel cls: ExpiringModel
objects = ( objects = (
@ -31,13 +31,16 @@ def clean_expired_models():
for obj in objects: for obj in objects:
obj.expire_action() obj.expire_action()
LOGGER.debug("Expired models", model=cls, amount=amount) LOGGER.debug("Expired models", model=cls, amount=amount)
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}")
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@actor(description=_("Remove temporary users created by SAML Sources.")) @CELERY_APP.task(bind=True, base=SystemTask)
def clean_temporary_users(): @prefill_task
self: Task = CurrentTask.get_task() def clean_temporary_users(self: SystemTask):
"""Remove temporary users created by SAML Sources"""
_now = datetime.now() _now = datetime.now()
messages = []
deleted_users = 0 deleted_users = 0
for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}):
if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): if not user.attributes.get(USER_ATTRIBUTE_EXPIRES):
@ -49,4 +52,5 @@ def clean_temporary_users():
LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta)
user.delete() user.delete()
deleted_users += 1 deleted_users += 1
self.info(f"Successfully deleted {deleted_users} users.") messages.append(f"Successfully deleted {deleted_users} users.")
self.set_status(TaskStatus.SUCCESSFUL, *messages)

View File

@ -36,7 +36,7 @@ class TestTasks(APITestCase):
expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API
) )
key = token.key key = token.key
clean_expired_models.send() clean_expired_models.delay().get()
token.refresh_from_db() token.refresh_from_db()
self.assertNotEqual(key, token.key) self.assertNotEqual(key, token.key)
@ -50,5 +50,5 @@ class TestTasks(APITestCase):
USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()),
}, },
) )
clean_temporary_users.send() clean_temporary_users.delay().get()
self.assertFalse(User.objects.filter(username=username)) self.assertFalse(User.objects.filter(username=username))

View File

@ -4,8 +4,6 @@ from datetime import UTC, datetime
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" MANAGED_KEY = "goauthentik.io/crypto/jwt-managed"
@ -69,14 +67,3 @@ class AuthentikCryptoConfig(ManagedAppConfig):
"key_data": builder.private_key, "key_data": builder.private_key,
}, },
) )
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.crypto.tasks import certificate_discovery
return [
ScheduleSpec(
actor=certificate_discovery,
crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *",
),
]

View File

@ -0,0 +1,13 @@
"""Crypto task Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"crypto_certificate_discovery": {
"task": "authentik.crypto.tasks.certificate_discovery",
"schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509.base import load_pem_x509_certificate from cryptography.x509.base import load_pem_x509_certificate
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
@ -36,9 +36,10 @@ def ensure_certificate_valid(body: str):
return body return body
@actor(description=_("Discover, import and update certificates from the filesystem.")) @CELERY_APP.task(bind=True, base=SystemTask)
def certificate_discovery(): @prefill_task
self: Task = CurrentTask.get_task() def certificate_discovery(self: SystemTask):
"""Discover, import and update certificates from the filesystem"""
certs = {} certs = {}
private_keys = {} private_keys = {}
discovered = 0 discovered = 0
@ -83,4 +84,6 @@ def certificate_discovery():
dirty = True dirty = True
if dirty: if dirty:
cert.save() cert.save()
self.info(f"Successfully imported {discovered} files.") self.set_status(
TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered))
)

View File

@ -338,7 +338,7 @@ class TestCrypto(APITestCase):
with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key:
_key.write(builder.private_key) _key.write(builder.private_key)
with CONFIG.patch("cert_discovery_dir", temp_dir): with CONFIG.patch("cert_discovery_dir", temp_dir):
certificate_discovery.send() certificate_discovery()
keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( keypair: CertificateKeyPair = CertificateKeyPair.objects.filter(
managed=MANAGED_DISCOVERED % "foo" managed=MANAGED_DISCOVERED % "foo"
).first() ).first()

View File

@ -3,8 +3,6 @@
from django.conf import settings from django.conf import settings
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
class EnterpriseConfig(ManagedAppConfig): class EnterpriseConfig(ManagedAppConfig):
@ -28,14 +26,3 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().status.is_valid return LicenseKey.cached_summary().status.is_valid
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.enterprise.tasks import enterprise_update_usage
return [
ScheduleSpec(
actor=enterprise_update_usage,
crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *",
),
]

View File

@ -1,8 +1,6 @@
"""authentik Unique Password policy app config""" """authentik Unique Password policy app config"""
from authentik.enterprise.apps import EnterpriseConfig from authentik.enterprise.apps import EnterpriseConfig
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
@ -10,21 +8,3 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig):
label = "authentik_policies_unique_password" label = "authentik_policies_unique_password"
verbose_name = "authentik Enterprise.Policies.Unique Password" verbose_name = "authentik Enterprise.Policies.Unique Password"
default = True default = True
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.enterprise.policies.unique_password.tasks import (
check_and_purge_password_history,
trim_password_histories,
)
return [
ScheduleSpec(
actor=trim_password_histories,
crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *",
),
ScheduleSpec(
actor=check_and_purge_password_history,
crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *",
),
]

View File

@ -0,0 +1,20 @@
"""Unique Password Policy settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"policies_unique_password_trim_history": {
"task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories",
"schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"),
"options": {"queue": "authentik_scheduled"},
},
"policies_unique_password_check_purge": {
"task": (
"authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history"
),
"schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,37 +1,35 @@
from django.db.models.aggregates import Count from django.db.models.aggregates import Count
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from structlog import get_logger from structlog import get_logger
from authentik.enterprise.policies.unique_password.models import ( from authentik.enterprise.policies.unique_password.models import (
UniquePasswordPolicy, UniquePasswordPolicy,
UserPasswordHistory, UserPasswordHistory,
) )
from authentik.tasks.models import Task from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
@actor( @CELERY_APP.task(bind=True, base=SystemTask)
description=_( @prefill_task
"Check if any UniquePasswordPolicy exists, and if not, purge the password history table." def check_and_purge_password_history(self: SystemTask):
) """Check if any UniquePasswordPolicy exists, and if not, purge the password history table.
) This is run on a schedule instead of being triggered by policy binding deletion.
def check_and_purge_password_history(): """
self: Task = CurrentTask.get_task()
if not UniquePasswordPolicy.objects.exists(): if not UniquePasswordPolicy.objects.exists():
UserPasswordHistory.objects.all().delete() UserPasswordHistory.objects.all().delete()
LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") LOGGER.debug("Purged UserPasswordHistory table as no policies are in use")
self.info("Successfully purged UserPasswordHistory") self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory")
return return
self.info("Not purging password histories, a unique password policy exists") self.set_status(
TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists"
)
@actor(description=_("Remove user password history that are too old.")) @CELERY_APP.task(bind=True, base=SystemTask)
def trim_password_histories(): def trim_password_histories(self: SystemTask):
"""Removes rows from UserPasswordHistory older than """Removes rows from UserPasswordHistory older than
the `n` most recent entries. the `n` most recent entries.
@ -39,8 +37,6 @@ def trim_password_histories():
UniquePasswordPolicy policies. UniquePasswordPolicy policies.
""" """
self: Task = CurrentTask.get_task()
# No policy, we'll let the cleanup above do its thing # No policy, we'll let the cleanup above do its thing
if not UniquePasswordPolicy.objects.exists(): if not UniquePasswordPolicy.objects.exists():
return return
@ -67,4 +63,4 @@ def trim_password_histories():
num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete()
LOGGER.debug("Deleted stale password history records", count=num_deleted) LOGGER.debug("Deleted stale password history records", count=num_deleted)
self.info(f"Delete {num_deleted} stale password history records") self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records")

View File

@ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
self.assertTrue(UserPasswordHistory.objects.exists()) self.assertTrue(UserPasswordHistory.objects.exists())
# Run the task - should purge since no policy is in use # Run the task - should purge since no policy is in use
check_and_purge_password_history.send() check_and_purge_password_history()
# Verify the table is empty # Verify the table is empty
self.assertFalse(UserPasswordHistory.objects.exists()) self.assertFalse(UserPasswordHistory.objects.exists())
@ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase):
self.assertTrue(UserPasswordHistory.objects.exists()) self.assertTrue(UserPasswordHistory.objects.exists())
# Run the task - should NOT purge since a policy is in use # Run the task - should NOT purge since a policy is in use
check_and_purge_password_history.send() check_and_purge_password_history()
# Verify the entries still exist # Verify the entries still exist
self.assertTrue(UserPasswordHistory.objects.exists()) self.assertTrue(UserPasswordHistory.objects.exists())
@ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase):
enabled=True, enabled=True,
order=0, order=0,
) )
trim_password_histories.send() trim_password_histories.delay()
user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user)
self.assertEqual(len(user_pwd_history_qs), 1) self.assertEqual(len(user_pwd_history_qs), 1)
@ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase):
enabled=False, enabled=False,
order=0, order=0,
) )
trim_password_histories.send() trim_password_histories.delay()
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())
def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): def test_trim_password_history_fewer_records_than_maximum_is_no_op(self):
@ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase):
enabled=True, enabled=True,
order=0, order=0,
) )
trim_password_histories.send() trim_password_histories.delay()
self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists())

View File

@ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi
] ]
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_task = google_workspace_sync sync_single_task = google_workspace_sync
sync_objects_task = google_workspace_sync_objects sync_objects_task = google_workspace_sync_objects

View File

@ -7,7 +7,6 @@ from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.templatetags.static import static from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from google.oauth2.service_account import Credentials from google.oauth2.service_account import Credentials
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -111,12 +110,6 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider):
help_text=_("Property mappings used for group creation/updating."), help_text=_("Property mappings used for group creation/updating."),
) )
@property
def sync_actor(self) -> Actor:
from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync
return google_workspace_sync
def client_for_model( def client_for_model(
self, self,
model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup],

View File

@ -0,0 +1,13 @@
"""Google workspace provider task Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_google_workspace_sync": {
"task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all",
"schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -2,13 +2,15 @@
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.enterprise.providers.google_workspace.tasks import ( from authentik.enterprise.providers.google_workspace.tasks import (
google_workspace_sync_direct_dispatch, google_workspace_sync,
google_workspace_sync_m2m_dispatch, google_workspace_sync_direct,
google_workspace_sync_m2m,
) )
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
register_signals( register_signals(
GoogleWorkspaceProvider, GoogleWorkspaceProvider,
task_sync_direct_dispatch=google_workspace_sync_direct_dispatch, task_sync_single=google_workspace_sync,
task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch, task_sync_direct=google_workspace_sync_direct,
task_sync_m2m=google_workspace_sync_m2m,
) )

View File

@ -1,48 +1,37 @@
"""Google Provider tasks""" """Google Provider tasks"""
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
from authentik.lib.sync.outgoing.tasks import SyncTasks from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.root.celery import CELERY_APP
sync_tasks = SyncTasks(GoogleWorkspaceProvider) sync_tasks = SyncTasks(GoogleWorkspaceProvider)
@actor(description=_("Sync Google Workspace provider objects.")) @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def google_workspace_sync_objects(*args, **kwargs): def google_workspace_sync_objects(*args, **kwargs):
return sync_tasks.sync_objects(*args, **kwargs) return sync_tasks.sync_objects(*args, **kwargs)
@actor(description=_("Full sync for Google Workspace provider.")) @CELERY_APP.task(
def google_workspace_sync(provider_pk: int, *args, **kwargs): base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
)
def google_workspace_sync(self, provider_pk: int, *args, **kwargs):
"""Run full sync for Google Workspace provider""" """Run full sync for Google Workspace provider"""
return sync_tasks.sync(provider_pk, google_workspace_sync_objects) return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects)
@actor(description=_("Sync a direct object (user, group) for Google Workspace provider.")) @CELERY_APP.task()
def google_workspace_sync_all():
return sync_tasks.sync_all(google_workspace_sync)
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def google_workspace_sync_direct(*args, **kwargs): def google_workspace_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor( @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
description=_(
"Dispatch syncs for a direct object (user, group) for Google Workspace providers."
)
)
def google_workspace_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Google Workspace provider."))
def google_workspace_sync_m2m(*args, **kwargs): def google_workspace_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Google Workspace providers."
)
)
def google_workspace_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs)

View File

@ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase):
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}), MagicMock(return_value={"developerKey": self.api_key, "http": http}),
): ):
google_workspace_sync.send(self.provider.pk).get_result() google_workspace_sync.delay(self.provider.pk).get()
self.assertTrue( self.assertTrue(
GoogleWorkspaceProviderGroup.objects.filter( GoogleWorkspaceProviderGroup.objects.filter(
group=different_group, provider=self.provider group=different_group, provider=self.provider

View File

@ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase):
"authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials",
MagicMock(return_value={"developerKey": self.api_key, "http": http}), MagicMock(return_value={"developerKey": self.api_key, "http": http}),
): ):
google_workspace_sync.send(self.provider.pk).get_result() google_workspace_sync.delay(self.provider.pk).get()
self.assertTrue( self.assertTrue(
GoogleWorkspaceProviderUser.objects.filter( GoogleWorkspaceProviderUser.objects.filter(
user=different_user, provider=self.provider user=different_user, provider=self.provider

View File

@ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin
] ]
search_fields = ["name"] search_fields = ["name"]
ordering = ["name"] ordering = ["name"]
sync_task = microsoft_entra_sync sync_single_task = microsoft_entra_sync
sync_objects_task = microsoft_entra_sync_objects sync_objects_task = microsoft_entra_sync_objects

View File

@ -8,7 +8,6 @@ from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.templatetags.static import static from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from authentik.core.models import ( from authentik.core.models import (
@ -100,12 +99,6 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider):
help_text=_("Property mappings used for group creation/updating."), help_text=_("Property mappings used for group creation/updating."),
) )
@property
def sync_actor(self) -> Actor:
from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync
return microsoft_entra_sync
def client_for_model( def client_for_model(
self, self,
model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup],

View File

@ -0,0 +1,13 @@
"""Microsoft Entra provider task Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_microsoft_entra_sync": {
"task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all",
"schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -2,13 +2,15 @@
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.enterprise.providers.microsoft_entra.tasks import ( from authentik.enterprise.providers.microsoft_entra.tasks import (
microsoft_entra_sync_direct_dispatch, microsoft_entra_sync,
microsoft_entra_sync_m2m_dispatch, microsoft_entra_sync_direct,
microsoft_entra_sync_m2m,
) )
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
register_signals( register_signals(
MicrosoftEntraProvider, MicrosoftEntraProvider,
task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch, task_sync_single=microsoft_entra_sync,
task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch, task_sync_direct=microsoft_entra_sync_direct,
task_sync_m2m=microsoft_entra_sync_m2m,
) )

View File

@ -1,46 +1,37 @@
"""Microsoft Entra Provider tasks""" """Microsoft Entra Provider tasks"""
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
from authentik.events.system_tasks import SystemTask
from authentik.lib.sync.outgoing.exceptions import TransientSyncException
from authentik.lib.sync.outgoing.tasks import SyncTasks from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.root.celery import CELERY_APP
sync_tasks = SyncTasks(MicrosoftEntraProvider) sync_tasks = SyncTasks(MicrosoftEntraProvider)
@actor(description=_("Sync Microsoft Entra provider objects.")) @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def microsoft_entra_sync_objects(*args, **kwargs): def microsoft_entra_sync_objects(*args, **kwargs):
return sync_tasks.sync_objects(*args, **kwargs) return sync_tasks.sync_objects(*args, **kwargs)
@actor(description=_("Full sync for Microsoft Entra provider.")) @CELERY_APP.task(
def microsoft_entra_sync(provider_pk: int, *args, **kwargs): base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
)
def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs):
"""Run full sync for Microsoft Entra provider""" """Run full sync for Microsoft Entra provider"""
return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects) return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects)
@actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider.")) @CELERY_APP.task()
def microsoft_entra_sync_all():
return sync_tasks.sync_all(microsoft_entra_sync)
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def microsoft_entra_sync_direct(*args, **kwargs): def microsoft_entra_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor( @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.")
)
def microsoft_entra_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for Microsoft Entra provider."))
def microsoft_entra_sync_m2m(*args, **kwargs): def microsoft_entra_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(
description=_(
"Dispatch syncs for a related object (memberships) for Microsoft Entra providers."
)
)
def microsoft_entra_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs)

View File

@ -252,13 +252,9 @@ class MicrosoftEntraGroupTests(TestCase):
member_add.assert_called_once() member_add.assert_called_once()
self.assertEqual( self.assertEqual(
member_add.call_args[0][0].odata_id, member_add.call_args[0][0].odata_id,
f"https://graph.microsoft.com/v1.0/directoryObjects/{ f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
MicrosoftEntraProviderUser.objects.filter(
provider=self.provider, provider=self.provider,
) ).first().microsoft_id}",
.first()
.microsoft_id
}",
) )
def test_group_create_member_remove(self): def test_group_create_member_remove(self):
@ -315,13 +311,9 @@ class MicrosoftEntraGroupTests(TestCase):
member_add.assert_called_once() member_add.assert_called_once()
self.assertEqual( self.assertEqual(
member_add.call_args[0][0].odata_id, member_add.call_args[0][0].odata_id,
f"https://graph.microsoft.com/v1.0/directoryObjects/{ f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter(
MicrosoftEntraProviderUser.objects.filter(
provider=self.provider, provider=self.provider,
) ).first().microsoft_id}",
.first()
.microsoft_id
}",
) )
member_remove.assert_called_once() member_remove.assert_called_once()
@ -421,7 +413,7 @@ class MicrosoftEntraGroupTests(TestCase):
), ),
) as group_list, ) as group_list,
): ):
microsoft_entra_sync.send(self.provider.pk).get_result() microsoft_entra_sync.delay(self.provider.pk).get()
self.assertTrue( self.assertTrue(
MicrosoftEntraProviderGroup.objects.filter( MicrosoftEntraProviderGroup.objects.filter(
group=different_group, provider=self.provider group=different_group, provider=self.provider

View File

@ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase):
AsyncMock(return_value=GroupCollectionResponse(value=[])), AsyncMock(return_value=GroupCollectionResponse(value=[])),
), ),
): ):
microsoft_entra_sync.send(self.provider.pk).get_result() microsoft_entra_sync.delay(self.provider.pk).get()
self.assertTrue( self.assertTrue(
MicrosoftEntraProviderUser.objects.filter( MicrosoftEntraProviderUser.objects.filter(
user=different_user, provider=self.provider user=different_user, provider=self.provider

View File

@ -17,7 +17,6 @@ from authentik.crypto.models import CertificateKeyPair
from authentik.lib.models import CreatedUpdatedModel from authentik.lib.models import CreatedUpdatedModel
from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider
from authentik.tasks.models import TasksModel
class EventTypes(models.TextChoices): class EventTypes(models.TextChoices):
@ -43,7 +42,7 @@ class SSFEventStatus(models.TextChoices):
SENT = "sent" SENT = "sent"
class SSFProvider(TasksModel, BackchannelProvider): class SSFProvider(BackchannelProvider):
"""Shared Signals Framework provider to allow applications to """Shared Signals Framework provider to allow applications to
receive user events from authentik.""" receive user events from authentik."""

View File

@ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import (
EventTypes, EventTypes,
SSFProvider, SSFProvider,
) )
from authentik.enterprise.providers.ssf.tasks import send_ssf_events from authentik.enterprise.providers.ssf.tasks import send_ssf_event
from authentik.events.middleware import audit_ignore from authentik.events.middleware import audit_ignore
from authentik.stages.authenticator.models import Device from authentik.stages.authenticator.models import Device
from authentik.stages.authenticator_duo.models import DuoDevice from authentik.stages.authenticator_duo.models import DuoDevice
@ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
As this signal is also triggered with a regular logout, we can't be sure As this signal is also triggered with a regular logout, we can't be sure
if the session has been deleted by an admin or by the user themselves.""" if the session has been deleted by an admin or by the user themselves."""
send_ssf_events( send_ssf_event(
EventTypes.CAEP_SESSION_REVOKED, EventTypes.CAEP_SESSION_REVOKED,
{ {
"initiating_entity": "user", "initiating_entity": "user",
@ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi
@receiver(password_changed) @receiver(password_changed)
def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_):
"""Credential change trigger (password changed)""" """Credential change trigger (password changed)"""
send_ssf_events( send_ssf_event(
EventTypes.CAEP_CREDENTIAL_CHANGE, EventTypes.CAEP_CREDENTIAL_CHANGE,
{ {
"credential_type": "password", "credential_type": "password",
@ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, *
} }
if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
data["fido2_aaguid"] = instance.aaguid data["fido2_aaguid"] = instance.aaguid
send_ssf_events( send_ssf_event(
EventTypes.CAEP_CREDENTIAL_CHANGE, EventTypes.CAEP_CREDENTIAL_CHANGE,
data, data,
sub_id={ sub_id={
@ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_):
} }
if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID:
data["fido2_aaguid"] = instance.aaguid data["fido2_aaguid"] = instance.aaguid
send_ssf_events( send_ssf_event(
EventTypes.CAEP_CREDENTIAL_CHANGE, EventTypes.CAEP_CREDENTIAL_CHANGE,
data, data,
sub_id={ sub_id={

View File

@ -1,11 +1,7 @@
from typing import Any from celery import group
from uuid import UUID
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from requests.exceptions import RequestException from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -17,16 +13,19 @@ from authentik.enterprise.providers.ssf.models import (
Stream, Stream,
StreamEvent, StreamEvent,
) )
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.utils.http import get_http_session from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
session = get_http_session() session = get_http_session()
LOGGER = get_logger() LOGGER = get_logger()
def send_ssf_events( def send_ssf_event(
event_type: EventTypes, event_type: EventTypes,
data: dict, data: dict,
stream_filter: dict | None = None, stream_filter: dict | None = None,
@ -34,7 +33,7 @@ def send_ssf_events(
**extra_data, **extra_data,
): ):
"""Wrapper to send an SSF event to multiple streams""" """Wrapper to send an SSF event to multiple streams"""
events_data = {} payload = []
if not stream_filter: if not stream_filter:
stream_filter = {} stream_filter = {}
stream_filter["events_requested__contains"] = [event_type] stream_filter["events_requested__contains"] = [event_type]
@ -42,22 +41,16 @@ def send_ssf_events(
extra_data.setdefault("txn", request.request_id) extra_data.setdefault("txn", request.request_id)
for stream in Stream.objects.filter(**stream_filter): for stream in Stream.objects.filter(**stream_filter):
event_data = stream.prepare_event_payload(event_type, data, **extra_data) event_data = stream.prepare_event_payload(event_type, data, **extra_data)
events_data[stream.uuid] = event_data payload.append((str(stream.uuid), event_data))
ssf_events_dispatch.send(events_data) return _send_ssf_event.delay(payload)
@actor(description=_("Dispatch SSF events.")) def _check_app_access(stream_uuid: str, event_data: dict) -> bool:
def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]):
for stream_uuid, event_data in events_data.items():
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:
continue
send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider)
def _check_app_access(stream: Stream, event_data: dict) -> bool:
"""Check if event is related to user and if so, check """Check if event is related to user and if so, check
if the user has access to the application""" if the user has access to the application"""
stream = Stream.objects.filter(pk=stream_uuid).first()
if not stream:
return False
# `event_data` is a dict version of a StreamEvent # `event_data` is a dict version of a StreamEvent
sub_id = event_data.get("payload", {}).get("sub_id", {}) sub_id = event_data.get("payload", {}).get("sub_id", {})
email = sub_id.get("user", {}).get("email", None) email = sub_id.get("user", {}).get("email", None)
@ -72,22 +65,42 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool:
return engine.passing return engine.passing
@actor(description=_("Send an SSF event.")) @CELERY_APP.task()
def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): def _send_ssf_event(event_data: list[tuple[str, dict]]):
self: Task = CurrentTask.get_task() tasks = []
for stream, data in event_data:
if not _check_app_access(stream, data):
continue
event = StreamEvent.objects.create(**data)
tasks.extend(send_single_ssf_event(stream, str(event.uuid)))
main_task = group(*tasks)
main_task()
stream = Stream.objects.filter(pk=stream_uuid).first()
def send_single_ssf_event(stream_id: str, evt_id: str):
stream = Stream.objects.filter(pk=stream_id).first()
if not stream: if not stream:
return return
if not _check_app_access(stream, event_data): event = StreamEvent.objects.filter(pk=evt_id).first()
if not event:
return return
event = StreamEvent.objects.create(**event_data)
self.set_uid(event.pk)
if event.status == SSFEventStatus.SENT: if event.status == SSFEventStatus.SENT:
return return
if stream.delivery_method != DeliveryMethods.RISC_PUSH: if stream.delivery_method == DeliveryMethods.RISC_PUSH:
return return [ssf_push_event.si(str(event.pk))]
return []
@CELERY_APP.task(bind=True, base=SystemTask)
def ssf_push_event(self: SystemTask, event_id: str):
self.save_on_success = False
event = StreamEvent.objects.filter(pk=event_id).first()
if not event:
return
self.set_uid(event_id)
if event.status == SSFEventStatus.SENT:
self.set_status(TaskStatus.SUCCESSFUL)
return
try: try:
response = session.post( response = session.post(
event.stream.endpoint_url, event.stream.endpoint_url,
@ -97,17 +110,26 @@ def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]):
response.raise_for_status() response.raise_for_status()
event.status = SSFEventStatus.SENT event.status = SSFEventStatus.SENT
event.save() event.save()
self.set_status(TaskStatus.SUCCESSFUL)
return return
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Failed to send SSF event", exc=exc) LOGGER.warning("Failed to send SSF event", exc=exc)
self.set_status(TaskStatus.ERROR)
attrs = {} attrs = {}
if exc.response: if exc.response:
attrs["response"] = { attrs["response"] = {
"content": exc.response.text, "content": exc.response.text,
"status": exc.response.status_code, "status": exc.response.status_code,
} }
self.warning(exc) self.set_error(
self.warning("Failed to send request", **attrs) exc,
LogEvent(
_("Failed to send request"),
log_level="warning",
logger=self.__name__,
attributes=attrs,
),
)
# Re-up the expiry of the stream event # Re-up the expiry of the stream event
event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) event.expires = now() + timedelta_from_string(event.stream.provider.event_retention)
event.status = SSFEventStatus.PENDING_FAILED event.status = SSFEventStatus.PENDING_FAILED

View File

@ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import (
SSFProvider, SSFProvider,
Stream, Stream,
) )
from authentik.enterprise.providers.ssf.tasks import send_ssf_events from authentik.enterprise.providers.ssf.tasks import send_ssf_event
from authentik.enterprise.providers.ssf.views.base import SSFView from authentik.enterprise.providers.ssf.views.base import SSFView
LOGGER = get_logger() LOGGER = get_logger()
@ -109,7 +109,7 @@ class StreamView(SSFView):
"User does not have permission to create stream for this provider." "User does not have permission to create stream for this provider."
) )
instance: Stream = stream.save(provider=self.provider) instance: Stream = stream.save(provider=self.provider)
send_ssf_events( send_ssf_event(
EventTypes.SET_VERIFICATION, EventTypes.SET_VERIFICATION,
{ {
"state": None, "state": None,

View File

@ -1,5 +1,17 @@
"""Enterprise additional settings""" """Enterprise additional settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"enterprise_update_usage": {
"task": "authentik.enterprise.tasks.enterprise_update_usage",
"schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"),
"options": {"queue": "authentik_scheduled"},
}
}
TENANT_APPS = [ TENANT_APPS = [
"authentik.enterprise.audit", "authentik.enterprise.audit",
"authentik.enterprise.policies.unique_password", "authentik.enterprise.policies.unique_password",

View File

@ -10,7 +10,6 @@ from django.utils.timezone import get_current_timezone
from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE
from authentik.enterprise.models import License from authentik.enterprise.models import License
from authentik.enterprise.tasks import enterprise_update_usage from authentik.enterprise.tasks import enterprise_update_usage
from authentik.tasks.schedules.models import Schedule
@receiver(pre_save, sender=License) @receiver(pre_save, sender=License)
@ -27,7 +26,7 @@ def pre_save_license(sender: type[License], instance: License, **_):
def post_save_license(sender: type[License], instance: License, **_): def post_save_license(sender: type[License], instance: License, **_):
"""Trigger license usage calculation when license is saved""" """Trigger license usage calculation when license is saved"""
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
Schedule.dispatch_by_actor(enterprise_update_usage) enterprise_update_usage.delay()
@receiver(post_delete, sender=License) @receiver(post_delete, sender=License)

View File

@ -1,11 +1,14 @@
"""Enterprise tasks""" """Enterprise tasks"""
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from authentik.enterprise.license import LicenseKey from authentik.enterprise.license import LicenseKey
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.root.celery import CELERY_APP
@actor(description=_("Update enterprise license status.")) @CELERY_APP.task(bind=True, base=SystemTask)
def enterprise_update_usage(): @prefill_task
def enterprise_update_usage(self: SystemTask):
"""Update enterprise license status"""
LicenseKey.get_total().record_usage() LicenseKey.get_total().record_usage()
self.set_status(TaskStatus.SUCCESSFUL)

View File

@ -0,0 +1,104 @@
"""Tasks API"""
from importlib import import_module
from django.contrib import messages
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.decorators import action
from rest_framework.fields import (
CharField,
ChoiceField,
DateTimeField,
FloatField,
SerializerMethodField,
)
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet
from structlog.stdlib import get_logger
from authentik.core.api.utils import ModelSerializer
from authentik.events.logs import LogEventSerializer
from authentik.events.models import SystemTask, TaskStatus
from authentik.rbac.decorators import permission_required
LOGGER = get_logger()
class SystemTaskSerializer(ModelSerializer):
"""Serialize TaskInfo and TaskResult"""
name = CharField()
full_name = SerializerMethodField()
uid = CharField(required=False)
description = CharField()
start_timestamp = DateTimeField(read_only=True)
finish_timestamp = DateTimeField(read_only=True)
duration = FloatField(read_only=True)
status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus])
messages = LogEventSerializer(many=True)
def get_full_name(self, instance: SystemTask) -> str:
"""Get full name with UID"""
if instance.uid:
return f"{instance.name}:{instance.uid}"
return instance.name
class Meta:
model = SystemTask
fields = [
"uuid",
"name",
"full_name",
"uid",
"description",
"start_timestamp",
"finish_timestamp",
"duration",
"status",
"messages",
"expires",
"expiring",
]
class SystemTaskViewSet(ReadOnlyModelViewSet):
"""Read-only view set that returns all background tasks"""
queryset = SystemTask.objects.all()
serializer_class = SystemTaskSerializer
filterset_fields = ["name", "uid", "status"]
ordering = ["name", "uid", "status"]
search_fields = ["name", "description", "uid", "status"]
@permission_required(None, ["authentik_events.run_task"])
@extend_schema(
request=OpenApiTypes.NONE,
responses={
204: OpenApiResponse(description="Task retried successfully"),
404: OpenApiResponse(description="Task not found"),
500: OpenApiResponse(description="Failed to retry task"),
},
)
@action(detail=True, methods=["POST"], permission_classes=[])
def run(self, request: Request, pk=None) -> Response:
"""Run task"""
task: SystemTask = self.get_object()
try:
task_module = import_module(task.task_call_module)
task_func = getattr(task_module, task.task_call_func)
LOGGER.info("Running task", task=task_func)
task_func.delay(*task.task_call_args, **task.task_call_kwargs)
messages.success(
self.request,
_("Successfully started task {name}.".format_map({"name": task.name})),
)
return Response(status=204)
except (ImportError, AttributeError) as exc: # pragma: no cover
LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc)
# if we get an import error, the module path has probably changed
task.delete()
return Response(status=500)

View File

@ -1,11 +1,12 @@
"""authentik events app""" """authentik events app"""
from celery.schedules import crontab
from prometheus_client import Gauge, Histogram from prometheus_client import Gauge, Histogram
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG, ENV_PREFIX from authentik.lib.config import CONFIG, ENV_PREFIX
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.reflection import path_to_class
from authentik.tasks.schedules.lib import ScheduleSpec from authentik.root.celery import CELERY_APP
# TODO: Deprecated metric - remove in 2024.2 or later # TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS = Gauge( GAUGE_TASKS = Gauge(
@ -34,17 +35,6 @@ class AuthentikEventsConfig(ManagedAppConfig):
verbose_name = "authentik Events" verbose_name = "authentik Events"
default = True default = True
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.events.tasks import notification_cleanup
return [
ScheduleSpec(
actor=notification_cleanup,
crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *",
),
]
@ManagedAppConfig.reconcile_global @ManagedAppConfig.reconcile_global
def check_deprecations(self): def check_deprecations(self):
"""Check for config deprecations""" """Check for config deprecations"""
@ -66,3 +56,41 @@ class AuthentikEventsConfig(ManagedAppConfig):
replacement_env=replace_env, replacement_env=replace_env,
message=msg, message=msg,
).save() ).save()
@ManagedAppConfig.reconcile_tenant
def prefill_tasks(self):
"""Prefill tasks"""
from authentik.events.models import SystemTask
from authentik.events.system_tasks import _prefill_tasks
for task in _prefill_tasks:
if SystemTask.objects.filter(name=task.name).exists():
continue
task.save()
self.logger.debug("prefilled task", task_name=task.name)
@ManagedAppConfig.reconcile_tenant
def run_scheduled_tasks(self):
"""Run schedule tasks which are behind schedule (only applies
to tasks of which we keep metrics)"""
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask as CelerySystemTask
for task in CELERY_APP.conf["beat_schedule"].values():
schedule = task["schedule"]
if not isinstance(schedule, crontab):
continue
task_class: CelerySystemTask = path_to_class(task["task"])
if not isinstance(task_class, CelerySystemTask):
continue
db_task = task_class.db()
if not db_task:
continue
due, _ = schedule.is_due(db_task.finish_timestamp)
if due or db_task.status == TaskStatus.UNKNOWN:
self.logger.debug("Running past-due scheduled task", task=task["task"])
task_class.apply_async(
args=task.get("args", None),
kwargs=task.get("kwargs", None),
**task.get("options", {}),
)

View File

@ -1,22 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-24 15:36
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"),
]
operations = [
migrations.AlterModelOptions(
name="systemtask",
options={
"default_permissions": (),
"permissions": (),
"verbose_name": "System Task",
"verbose_name_plural": "System Tasks",
},
),
]

View File

@ -5,11 +5,12 @@ from datetime import timedelta
from difflib import get_close_matches from difflib import get_close_matches
from functools import lru_cache from functools import lru_cache
from inspect import currentframe from inspect import currentframe
from smtplib import SMTPException
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
from django.apps import apps from django.apps import apps
from django.db import models from django.db import connection, models
from django.http import HttpRequest from django.http import HttpRequest
from django.http.request import QueryDict from django.http.request import QueryDict
from django.utils.timezone import now from django.utils.timezone import now
@ -26,6 +27,7 @@ from authentik.core.middleware import (
SESSION_KEY_IMPERSONATE_USER, SESSION_KEY_IMPERSONATE_USER,
) )
from authentik.core.models import ExpiringModel, Group, PropertyMapping, User from authentik.core.models import ExpiringModel, Group, PropertyMapping, User
from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME
from authentik.events.context_processors.base import get_context_processors from authentik.events.context_processors.base import get_context_processors
from authentik.events.utils import ( from authentik.events.utils import (
cleanse_dict, cleanse_dict,
@ -41,7 +43,6 @@ from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.models import PolicyBindingModel from authentik.policies.models import PolicyBindingModel
from authentik.root.middleware import ClientIPMiddleware from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.email.utils import TemplateEmailMessage from authentik.stages.email.utils import TemplateEmailMessage
from authentik.tasks.models import TasksModel
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant from authentik.tenants.utils import get_current_tenant
@ -266,8 +267,7 @@ class Event(SerializerModel, ExpiringModel):
models.Index(fields=["created"]), models.Index(fields=["created"]),
models.Index(fields=["client_ip"]), models.Index(fields=["client_ip"]),
models.Index( models.Index(
models.F("context__authorized_application"), models.F("context__authorized_application"), name="authentik_e_ctx_app__idx"
name="authentik_e_ctx_app__idx",
), ),
] ]
@ -281,7 +281,7 @@ class TransportMode(models.TextChoices):
EMAIL = "email", _("Email") EMAIL = "email", _("Email")
class NotificationTransport(TasksModel, SerializerModel): class NotificationTransport(SerializerModel):
"""Action which is executed when a Rule matches""" """Action which is executed when a Rule matches"""
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
@ -446,8 +446,6 @@ class NotificationTransport(TasksModel, SerializerModel):
def send_email(self, notification: "Notification") -> list[str]: def send_email(self, notification: "Notification") -> list[str]:
"""Send notification via global email configuration""" """Send notification via global email configuration"""
from authentik.stages.email.tasks import send_mail
if notification.user.email.strip() == "": if notification.user.email.strip() == "":
LOGGER.info( LOGGER.info(
"Discarding notification as user has no email address", "Discarding notification as user has no email address",
@ -489,14 +487,17 @@ class NotificationTransport(TasksModel, SerializerModel):
template_name="email/event_notification.html", template_name="email/event_notification.html",
template_context=context, template_context=context,
) )
send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self) # Email is sent directly here, as the call to send() should have been from a task.
return [] try:
from authentik.stages.email.tasks import send_mail
return send_mail(mail.__dict__)
except (SMTPException, ConnectionError, OSError) as exc:
raise NotificationTransportError(exc) from exc
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:
from authentik.events.api.notification_transports import ( from authentik.events.api.notification_transports import NotificationTransportSerializer
NotificationTransportSerializer,
)
return NotificationTransportSerializer return NotificationTransportSerializer
@ -546,7 +547,7 @@ class Notification(SerializerModel):
verbose_name_plural = _("Notifications") verbose_name_plural = _("Notifications")
class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel): class NotificationRule(SerializerModel, PolicyBindingModel):
"""Decide when to create a Notification based on policies attached to this object.""" """Decide when to create a Notification based on policies attached to this object."""
name = models.TextField(unique=True) name = models.TextField(unique=True)
@ -610,9 +611,7 @@ class NotificationWebhookMapping(PropertyMapping):
@property @property
def serializer(self) -> type[type[Serializer]]: def serializer(self) -> type[type[Serializer]]:
from authentik.events.api.notification_mappings import ( from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer
NotificationWebhookMappingSerializer,
)
return NotificationWebhookMappingSerializer return NotificationWebhookMappingSerializer
@ -625,7 +624,7 @@ class NotificationWebhookMapping(PropertyMapping):
class TaskStatus(models.TextChoices): class TaskStatus(models.TextChoices):
"""DEPRECATED do not use""" """Possible states of tasks"""
UNKNOWN = "unknown" UNKNOWN = "unknown"
SUCCESSFUL = "successful" SUCCESSFUL = "successful"
@ -633,8 +632,8 @@ class TaskStatus(models.TextChoices):
ERROR = "error" ERROR = "error"
class SystemTask(ExpiringModel): class SystemTask(SerializerModel, ExpiringModel):
"""DEPRECATED do not use""" """Info about a system task running in the background along with details to restart the task"""
uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
name = models.TextField() name = models.TextField()
@ -654,13 +653,41 @@ class SystemTask(ExpiringModel):
task_call_args = models.JSONField(default=list) task_call_args = models.JSONField(default=list)
task_call_kwargs = models.JSONField(default=dict) task_call_kwargs = models.JSONField(default=dict)
@property
def serializer(self) -> type[Serializer]:
from authentik.events.api.tasks import SystemTaskSerializer
return SystemTaskSerializer
def update_metrics(self):
"""Update prometheus metrics"""
# TODO: Deprecated metric - remove in 2024.2 or later
GAUGE_TASKS.labels(
tenant=connection.schema_name,
task_name=self.name,
task_uid=self.uid or "",
status=self.status.lower(),
).set(self.duration)
SYSTEM_TASK_TIME.labels(
tenant=connection.schema_name,
task_name=self.name,
task_uid=self.uid or "",
).observe(self.duration)
SYSTEM_TASK_STATUS.labels(
tenant=connection.schema_name,
task_name=self.name,
task_uid=self.uid or "",
status=self.status.lower(),
).inc()
def __str__(self) -> str: def __str__(self) -> str:
return f"System Task {self.name}" return f"System Task {self.name}"
class Meta: class Meta:
unique_together = (("name", "uid"),) unique_together = (("name", "uid"),)
default_permissions = () # Remove "add", "change" and "delete" permissions as those are not used
permissions = () default_permissions = ["view"]
permissions = [("run_task", _("Run task"))]
verbose_name = _("System Task") verbose_name = _("System Task")
verbose_name_plural = _("System Tasks") verbose_name_plural = _("System Tasks")
indexes = ExpiringModel.Meta.indexes indexes = ExpiringModel.Meta.indexes

View File

@ -0,0 +1,13 @@
"""Event Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"events_notification_cleanup": {
"task": "authentik.events.tasks.notification_cleanup",
"schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -12,10 +12,13 @@ from rest_framework.request import Request
from authentik.core.models import AuthenticatedSession, User from authentik.core.models import AuthenticatedSession, User
from authentik.core.signals import login_failed, password_changed from authentik.core.signals import login_failed, password_changed
from authentik.events.models import Event, EventAction from authentik.events.apps import SYSTEM_TASK_STATUS
from authentik.events.models import Event, EventAction, SystemTask
from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.root.monitoring import monitoring_set
from authentik.stages.invitation.models import Invitation from authentik.stages.invitation.models import Invitation
from authentik.stages.invitation.signals import invitation_used from authentik.stages.invitation.signals import invitation_used
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -111,15 +114,19 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest
@receiver(post_save, sender=Event) @receiver(post_save, sender=Event)
def event_post_save_notification(sender, instance: Event, **_): def event_post_save_notification(sender, instance: Event, **_):
"""Start task to check if any policies trigger an notification on this event""" """Start task to check if any policies trigger an notification on this event"""
from authentik.events.tasks import event_trigger_dispatch event_notification_handler.delay(instance.event_uuid.hex)
event_trigger_dispatch.send(instance.event_uuid)
@receiver(pre_delete, sender=User) @receiver(pre_delete, sender=User)
def event_user_pre_delete_cleanup(sender, instance: User, **_): def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events""" """If gdpr_compliance is enabled, remove all the user's events"""
from authentik.events.tasks import gdpr_cleanup
if get_current_tenant().gdpr_compliance: if get_current_tenant().gdpr_compliance:
gdpr_cleanup.send(instance.pk) gdpr_cleanup.delay(instance.pk)
@receiver(monitoring_set)
def monitoring_system_task(sender, **_):
"""Update metrics when task is saved"""
SYSTEM_TASK_STATUS.clear()
for task in SystemTask.objects.all():
task.update_metrics()

View File

@ -0,0 +1,156 @@
"""Monitored tasks"""
from datetime import datetime, timedelta
from time import perf_counter
from typing import Any
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
from tenant_schemas_celery.task import TenantTask
from authentik.events.logs import LogEvent
from authentik.events.models import Event, EventAction, TaskStatus
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.utils import sanitize_item
from authentik.lib.utils.errors import exception_to_string
class SystemTask(TenantTask):
"""Task which can save its state to the cache"""
logger: BoundLogger
# For tasks that should only be listed if they failed, set this to False
save_on_success: bool
_status: TaskStatus
_messages: list[LogEvent]
_uid: str | None
# Precise start time from perf_counter
_start_precise: float | None = None
_start: datetime | None = None
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._status = TaskStatus.SUCCESSFUL
self.save_on_success = True
self._uid = None
self._status = None
self._messages = []
self.result_timeout_hours = 6
def set_uid(self, uid: str):
"""Set UID, so in the case of an unexpected error its saved correctly"""
self._uid = uid
def set_status(self, status: TaskStatus, *messages: LogEvent):
"""Set result for current run, will overwrite previous result."""
self._status = status
self._messages = list(messages)
for idx, msg in enumerate(self._messages):
if not isinstance(msg, LogEvent):
self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info")
def set_error(self, exception: Exception, *messages: LogEvent):
"""Set result to error and save exception"""
self._status = TaskStatus.ERROR
self._messages = list(messages)
self._messages.extend(
[LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")]
)
def before_start(self, task_id, args, kwargs):
self._start_precise = perf_counter()
self._start = now()
self.logger = get_logger().bind(task_id=task_id)
return super().before_start(task_id, args, kwargs)
def db(self) -> DBSystemTask | None:
"""Get DB object for latest task"""
return DBSystemTask.objects.filter(
name=self.__name__,
uid=self._uid,
).first()
def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo):
super().after_return(status, retval, task_id, args, kwargs, einfo=einfo)
if not self._status:
return
if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success:
DBSystemTask.objects.filter(
name=self.__name__,
uid=self._uid,
).delete()
return
DBSystemTask.objects.update_or_create(
name=self.__name__,
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or now(),
"finish_timestamp": now(),
"duration": max(perf_counter() - self._start_precise, 0),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": sanitize_item(args),
"task_call_kwargs": sanitize_item(kwargs),
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours),
"expiring": True,
},
)
def on_failure(self, exc, task_id, args, kwargs, einfo):
super().on_failure(exc, task_id, args, kwargs, einfo=einfo)
if not self._status:
self.set_error(exc)
DBSystemTask.objects.update_or_create(
name=self.__name__,
uid=self._uid,
defaults={
"description": self.__doc__,
"start_timestamp": self._start or now(),
"finish_timestamp": now(),
"duration": max(perf_counter() - self._start_precise, 0),
"task_call_module": self.__module__,
"task_call_func": self.__name__,
"task_call_args": sanitize_item(args),
"task_call_kwargs": sanitize_item(kwargs),
"status": self._status,
"messages": sanitize_item(self._messages),
"expires": now() + timedelta(hours=self.result_timeout_hours + 3),
"expiring": True,
},
)
Event.new(
EventAction.SYSTEM_TASK_EXCEPTION,
message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}",
).save()
def run(self, *args, **kwargs):
raise NotImplementedError
def prefill_task(func):
"""Ensure a task's details are always in cache, so it can always be triggered via API"""
_prefill_tasks.append(
DBSystemTask(
name=func.__name__,
description=func.__doc__,
start_timestamp=now(),
finish_timestamp=now(),
status=TaskStatus.UNKNOWN,
messages=sanitize_item([_("Task has not been run yet.")]),
task_call_module=func.__module__,
task_call_func=func.__name__,
expiring=False,
duration=0,
)
)
return func
_prefill_tasks = []

View File

@ -1,49 +1,41 @@
"""Event notification tasks""" """Event notification tasks"""
from uuid import UUID
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.expression.exceptions import PropertyMappingExpressionException
from authentik.core.models import User from authentik.core.models import User
from authentik.events.models import ( from authentik.events.models import (
Event, Event,
Notification, Notification,
NotificationRule, NotificationRule,
NotificationTransport, NotificationTransport,
NotificationTransportError,
TaskStatus,
) )
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
@actor(description=_("Dispatch new event notifications.")) @CELERY_APP.task()
def event_trigger_dispatch(event_uuid: UUID): def event_notification_handler(event_uuid: str):
"""Start task for each trigger definition"""
for trigger in NotificationRule.objects.all(): for trigger in NotificationRule.objects.all():
event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger) event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events")
@actor( @CELERY_APP.task()
description=_( def event_trigger_handler(event_uuid: str, trigger_name: str):
"Check if policies attached to NotificationRule match event "
"and dispatch notification tasks."
)
)
def event_trigger_handler(event_uuid: UUID, trigger_name: str):
"""Check if policies attached to NotificationRule match event""" """Check if policies attached to NotificationRule match event"""
self: Task = CurrentTask.get_task()
event: Event = Event.objects.filter(event_uuid=event_uuid).first() event: Event = Event.objects.filter(event_uuid=event_uuid).first()
if not event: if not event:
self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid)
return return
trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first()
if not trigger: if not trigger:
return return
@ -78,46 +70,57 @@ def event_trigger_handler(event_uuid: UUID, trigger_name: str):
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
# Create the notification objects # Create the notification objects
count = 0
for transport in trigger.transports.all(): for transport in trigger.transports.all():
for user in trigger.destination_users(event): for user in trigger.destination_users(event):
notification_transport.send_with_options( LOGGER.debug("created notification")
args=( notification_transport.apply_async(
args=[
transport.pk, transport.pk,
event.pk, str(event.pk),
user.pk, user.pk,
trigger.pk, str(trigger.pk),
), ],
rel_obj=transport, queue="authentik_events",
) )
count += 1
if transport.send_once: if transport.send_once:
break break
self.info(f"Created {count} notification tasks")
@actor(description=_("Send notification.")) @CELERY_APP.task(
def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): bind=True,
autoretry_for=(NotificationTransportError,),
retry_backoff=True,
base=SystemTask,
)
def notification_transport(
self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str
):
"""Send notification over specified transport""" """Send notification over specified transport"""
event = Event.objects.filter(pk=event_pk).first() self.save_on_success = False
if not event: try:
return event = Event.objects.filter(pk=event_pk).first()
user = User.objects.filter(pk=user_pk).first() if not event:
if not user: return
return user = User.objects.filter(pk=user_pk).first()
trigger = NotificationRule.objects.filter(pk=trigger_pk).first() if not user:
if not trigger: return
return trigger = NotificationRule.objects.filter(pk=trigger_pk).first()
notification = Notification( if not trigger:
severity=trigger.severity, body=event.summary, event=event, user=user return
) notification = Notification(
transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first() severity=trigger.severity, body=event.summary, event=event, user=user
if not transport: )
return transport = NotificationTransport.objects.filter(pk=transport_pk).first()
transport.send(notification) if not transport:
return
transport.send(notification)
self.set_status(TaskStatus.SUCCESSFUL)
except (NotificationTransportError, PropertyMappingExpressionException) as exc:
self.set_error(exc)
raise exc
@actor(description=_("Cleanup events for GDPR compliance.")) @CELERY_APP.task()
def gdpr_cleanup(user_pk: int): def gdpr_cleanup(user_pk: int):
"""cleanup events from gdpr_compliance""" """cleanup events from gdpr_compliance"""
events = Event.objects.filter(user__pk=user_pk) events = Event.objects.filter(user__pk=user_pk)
@ -125,12 +128,12 @@ def gdpr_cleanup(user_pk: int):
events.delete() events.delete()
@actor(description=_("Cleanup seen notifications and notifications whose event expired.")) @CELERY_APP.task(bind=True, base=SystemTask)
def notification_cleanup(): @prefill_task
def notification_cleanup(self: SystemTask):
"""Cleanup seen notifications and notifications whose event expired.""" """Cleanup seen notifications and notifications whose event expired."""
self: Task = CurrentTask.get_task()
notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) notifications = Notification.objects.filter(Q(event=None) | Q(seen=True))
amount = notifications.count() amount = notifications.count()
notifications.delete() notifications.delete()
LOGGER.debug("Expired notifications", amount=amount) LOGGER.debug("Expired notifications", amount=amount)
self.info(f"Expired {amount} Notifications") self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications")

View File

@ -0,0 +1,103 @@
"""Test Monitored tasks"""
from json import loads
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tasks import clean_expired_models
from authentik.core.tests.utils import create_test_admin_user
from authentik.events.models import SystemTask as DBSystemTask
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.lib.generators import generate_id
from authentik.root.celery import CELERY_APP
class TestSystemTasks(APITestCase):
"""Test Monitored tasks"""
def setUp(self):
super().setUp()
self.user = create_test_admin_user()
self.client.force_login(self.user)
def test_failed_successful_remove_state(self):
"""Test that a task with `save_on_success` set to `False` that failed saves
a state, and upon successful completion will delete the state"""
should_fail = True
uid = generate_id()
@CELERY_APP.task(
bind=True,
base=SystemTask,
)
def test_task(self: SystemTask):
self.save_on_success = False
self.set_uid(uid)
self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL)
# First test successful run
should_fail = False
test_task.delay().get()
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
# Then test failed
should_fail = True
test_task.delay().get()
task = DBSystemTask.objects.filter(name="test_task", uid=uid).first()
self.assertEqual(task.status, TaskStatus.ERROR)
# Then after that, the state should be removed
should_fail = False
test_task.delay().get()
self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first())
def test_tasks(self):
"""Test Task API"""
clean_expired_models.delay().get()
response = self.client.get(reverse("authentik_api:systemtask-list"))
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"]))
def test_tasks_single(self):
"""Test Task API (read single)"""
clean_expired_models.delay().get()
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.get(
reverse(
"authentik_api:systemtask-detail",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 200)
body = loads(response.content)
self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value)
self.assertEqual(body["name"], "clean_expired_models")
response = self.client.get(
reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"})
)
self.assertEqual(response.status_code, 404)
def test_tasks_run(self):
"""Test Task API (run)"""
clean_expired_models.delay().get()
task = DBSystemTask.objects.filter(name="clean_expired_models").first()
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": str(task.pk)},
)
)
self.assertEqual(response.status_code, 204)
def test_tasks_run_404(self):
"""Test Task API (run, 404)"""
response = self.client.post(
reverse(
"authentik_api:systemtask-run",
kwargs={"pk": "qwerqewrqrqewrqewr"},
)
)
self.assertEqual(response.status_code, 404)

View File

@ -5,11 +5,13 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin
from authentik.events.api.notification_rules import NotificationRuleViewSet from authentik.events.api.notification_rules import NotificationRuleViewSet
from authentik.events.api.notification_transports import NotificationTransportViewSet from authentik.events.api.notification_transports import NotificationTransportViewSet
from authentik.events.api.notifications import NotificationViewSet from authentik.events.api.notifications import NotificationViewSet
from authentik.events.api.tasks import SystemTaskViewSet
api_urlpatterns = [ api_urlpatterns = [
("events/events", EventViewSet), ("events/events", EventViewSet),
("events/notifications", NotificationViewSet), ("events/notifications", NotificationViewSet),
("events/transports", NotificationTransportViewSet), ("events/transports", NotificationTransportViewSet),
("events/rules", NotificationRuleViewSet), ("events/rules", NotificationRuleViewSet),
("events/system_tasks", SystemTaskViewSet),
("propertymappings/notification", NotificationWebhookMappingViewSet), ("propertymappings/notification", NotificationWebhookMappingViewSet),
] ]

View File

@ -41,7 +41,6 @@ REDIS_ENV_KEYS = [
# Old key -> new key # Old key -> new key
DEPRECATIONS = { DEPRECATIONS = {
"geoip": "events.context_processors.geoip", "geoip": "events.context_processors.geoip",
"worker.concurrency": "worker.processes",
"redis.broker_url": "broker.url", "redis.broker_url": "broker.url",
"redis.broker_transport_options": "broker.transport_options", "redis.broker_transport_options": "broker.transport_options",
"redis.cache_timeout": "cache.timeout", "redis.cache_timeout": "cache.timeout",

View File

@ -21,10 +21,6 @@ def start_debug_server(**kwargs) -> bool:
listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901")
host, _, port = listen.rpartition(":") host, _, port = listen.rpartition(":")
try: debugpy.listen((host, int(port)), **kwargs) # nosec
debugpy.listen((host, int(port)), **kwargs) # nosec
except RuntimeError:
LOGGER.warning("Could not start debug server. Continuing without")
return False
LOGGER.debug("Starting debug server", host=host, port=port) LOGGER.debug("Starting debug server", host=host, port=port)
return True return True

View File

@ -8,9 +8,9 @@
# make gen-dev-config # make gen-dev-config
# ``` # ```
# #
# You may edit the generated file to override the configuration below. # You may edit the generated file to override the configuration below.
# #
# When making modifying the default configuration file, # When making modifying the default configuration file,
# ensure that the corresponding documentation is updated to match. # ensure that the corresponding documentation is updated to match.
# #
# @see {@link ../../website/docs/install-config/configuration/configuration.mdx Configuration documentation} for more information. # @see {@link ../../website/docs/install-config/configuration/configuration.mdx Configuration documentation} for more information.
@ -157,14 +157,7 @@ web:
path: / path: /
worker: worker:
processes: 2 concurrency: 2
threads: 1
consumer_listen_timeout: "seconds=30"
task_max_retries: 20
task_default_time_limit: "minutes=10"
task_purge_interval: "days=1"
task_expiration: "days=30"
scheduler_interval: "seconds=60"
storage: storage:
media: media:

View File

@ -88,6 +88,7 @@ def get_logger_config():
"authentik": global_level, "authentik": global_level,
"django": "WARNING", "django": "WARNING",
"django.request": "ERROR", "django.request": "ERROR",
"celery": "WARNING",
"selenium": "WARNING", "selenium": "WARNING",
"docker": "WARNING", "docker": "WARNING",
"urllib3": "WARNING", "urllib3": "WARNING",

View File

@ -3,6 +3,8 @@
from asyncio.exceptions import CancelledError from asyncio.exceptions import CancelledError
from typing import Any from typing import Any
from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError
from celery.exceptions import CeleryError
from channels_redis.core import ChannelFull from channels_redis.core import ChannelFull
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError
@ -20,6 +22,7 @@ from sentry_sdk import HttpTransport, get_current_scope
from sentry_sdk import init as sentry_sdk_init from sentry_sdk import init as sentry_sdk_init
from sentry_sdk.api import set_tag from sentry_sdk.api import set_tag
from sentry_sdk.integrations.argv import ArgvIntegration from sentry_sdk.integrations.argv import ArgvIntegration
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.django import DjangoIntegration from sentry_sdk.integrations.django import DjangoIntegration
from sentry_sdk.integrations.redis import RedisIntegration from sentry_sdk.integrations.redis import RedisIntegration
from sentry_sdk.integrations.socket import SocketIntegration from sentry_sdk.integrations.socket import SocketIntegration
@ -68,6 +71,10 @@ ignored_classes = (
LocalProtocolError, LocalProtocolError,
# rest_framework error # rest_framework error
APIException, APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass # custom baseclass
SentryIgnoredException, SentryIgnoredException,
# ldap errors # ldap errors
@ -108,6 +115,7 @@ def sentry_init(**sentry_init_kwargs):
ArgvIntegration(), ArgvIntegration(),
StdlibIntegration(), StdlibIntegration(),
DjangoIntegration(transaction_style="function_name", cache_spans=True), DjangoIntegration(transaction_style="function_name", cache_spans=True),
CeleryIntegration(),
RedisIntegration(), RedisIntegration(),
ThreadingIntegration(propagate_hub=True), ThreadingIntegration(propagate_hub=True),
SocketIntegration(), SocketIntegration(),
@ -152,11 +160,14 @@ def before_send(event: dict, hint: dict) -> dict | None:
return None return None
if "logger" in event: if "logger" in event:
if event["logger"] in [ if event["logger"] in [
"kombu",
"asyncio", "asyncio",
"multiprocessing", "multiprocessing",
"django_redis", "django_redis",
"django.security.DisallowedHost", "django.security.DisallowedHost",
"django_redis.cache", "django_redis.cache",
"celery.backends.redis",
"celery.worker",
"paramiko.transport", "paramiko.transport",
]: ]:
return None return None

View File

@ -1,12 +0,0 @@
from rest_framework.fields import BooleanField, ChoiceField, DateTimeField
from authentik.core.api.utils import PassiveSerializer
from authentik.tasks.models import TaskStatus
class SyncStatusSerializer(PassiveSerializer):
"""Provider/source sync status"""
is_running = BooleanField()
last_successful_sync = DateTimeField(required=False)
last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices)

View File

@ -1,7 +1,7 @@
"""Sync constants""" """Sync constants"""
PAGE_SIZE = 100 PAGE_SIZE = 100
PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000 # Half an hour PAGE_TIMEOUT = 60 * 60 * 0.5 # Half an hour
HTTP_CONFLICT = 409 HTTP_CONFLICT = 409
HTTP_NO_CONTENT = 204 HTTP_NO_CONTENT = 204
HTTP_SERVICE_UNAVAILABLE = 503 HTTP_SERVICE_UNAVAILABLE = 503

View File

@ -1,5 +1,7 @@
from dramatiq.actor import Actor from celery import Task
from drf_spectacular.utils import extend_schema from django.utils.text import slugify
from drf_spectacular.utils import OpenApiResponse, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import BooleanField, CharField, ChoiceField from rest_framework.fields import BooleanField, CharField, ChoiceField
from rest_framework.request import Request from rest_framework.request import Request
@ -7,12 +9,18 @@ from rest_framework.response import Response
from authentik.core.api.utils import ModelSerializer, PassiveSerializer from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.logs import LogEventSerializer from authentik.events.api.tasks import SystemTaskSerializer
from authentik.lib.sync.api import SyncStatusSerializer from authentik.events.logs import LogEvent, LogEventSerializer
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.rbac.filters import ObjectFilter from authentik.rbac.filters import ObjectFilter
from authentik.tasks.models import Task, TaskStatus
class SyncStatusSerializer(PassiveSerializer):
"""Provider sync status"""
is_running = BooleanField(read_only=True)
tasks = SystemTaskSerializer(many=True, read_only=True)
class SyncObjectSerializer(PassiveSerializer): class SyncObjectSerializer(PassiveSerializer):
@ -37,10 +45,15 @@ class SyncObjectResultSerializer(PassiveSerializer):
class OutgoingSyncProviderStatusMixin: class OutgoingSyncProviderStatusMixin:
"""Common API Endpoints for Outgoing sync providers""" """Common API Endpoints for Outgoing sync providers"""
sync_task: Actor sync_single_task: type[Task] = None
sync_objects_task: Actor sync_objects_task: type[Task] = None
@extend_schema(responses={200: SyncStatusSerializer()}) @extend_schema(
responses={
200: SyncStatusSerializer(),
404: OpenApiResponse(description="Task not found"),
}
)
@action( @action(
methods=["GET"], methods=["GET"],
detail=True, detail=True,
@ -51,39 +64,18 @@ class OutgoingSyncProviderStatusMixin:
def sync_status(self, request: Request, pk: int) -> Response: def sync_status(self, request: Request, pk: int) -> Response:
"""Get provider's sync status""" """Get provider's sync status"""
provider: OutgoingSyncProvider = self.get_object() provider: OutgoingSyncProvider = self.get_object()
tasks = list(
status = {} get_objects_for_user(request.user, "authentik_events.view_systemtask").filter(
name=self.sync_single_task.__name__,
with provider.sync_lock as lock_acquired: uid=slugify(provider.name),
# If we could not acquire the lock, it means a task is using it, and thus is running
status["is_running"] = not lock_acquired
sync_schedule = None
for schedule in provider.schedules.all():
if schedule.actor_name == self.sync_task.actor_name:
sync_schedule = schedule
if not sync_schedule:
return Response(SyncStatusSerializer(status).data)
last_task: Task = (
sync_schedule.tasks.exclude(
aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED)
) )
.order_by("-mtime")
.first()
) )
last_successful_task: Task = ( with provider.sync_lock as lock_acquired:
sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) status = {
.order_by("-mtime") "tasks": tasks,
.first() # If we could not acquire the lock, it means a task is using it, and thus is running
) "is_running": not lock_acquired,
}
if last_task:
status["last_sync_status"] = last_task.aggregated_status
if last_successful_task:
status["last_successful_sync"] = last_successful_task.mtime
return Response(SyncStatusSerializer(status).data) return Response(SyncStatusSerializer(status).data)
@extend_schema( @extend_schema(
@ -102,20 +94,14 @@ class OutgoingSyncProviderStatusMixin:
provider: OutgoingSyncProvider = self.get_object() provider: OutgoingSyncProvider = self.get_object()
params = SyncObjectSerializer(data=request.data) params = SyncObjectSerializer(data=request.data)
params.is_valid(raise_exception=True) params.is_valid(raise_exception=True)
msg = self.sync_objects_task.send_with_options( res: list[LogEvent] = self.sync_objects_task.delay(
kwargs={ params.validated_data["sync_object_model"],
"object_type": params.validated_data["sync_object_model"], page=1,
"page": 1, provider_pk=provider.pk,
"provider_pk": provider.pk, pk=params.validated_data["sync_object_id"],
"override_dry_run": params.validated_data["override_dry_run"], override_dry_run=params.validated_data["override_dry_run"],
"pk": params.validated_data["sync_object_id"], ).get()
}, return Response(SyncObjectResultSerializer(instance={"messages": res}).data)
rel_obj=provider,
)
msg.get_result(block=True)
task: Task = msg.options["task"]
task.refresh_from_db()
return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data)
class OutgoingSyncConnectionCreateMixin: class OutgoingSyncConnectionCreateMixin:

View File

@ -1,18 +1,12 @@
from typing import Any, Self from typing import Any, Self
import pglock import pglock
from django.core.paginator import Paginator
from django.db import connection, models from django.db import connection, models
from django.db.models import Model, QuerySet, TextChoices from django.db.models import Model, QuerySet, TextChoices
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS
from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
from authentik.tasks.schedules.models import ScheduledModel
class OutgoingSyncDeleteAction(TextChoices): class OutgoingSyncDeleteAction(TextChoices):
@ -24,7 +18,7 @@ class OutgoingSyncDeleteAction(TextChoices):
SUSPEND = "suspend" SUSPEND = "suspend"
class OutgoingSyncProvider(ScheduledModel, Model): class OutgoingSyncProvider(Model):
"""Base abstract models for providers implementing outgoing sync""" """Base abstract models for providers implementing outgoing sync"""
dry_run = models.BooleanField( dry_run = models.BooleanField(
@ -45,19 +39,6 @@ class OutgoingSyncProvider(ScheduledModel, Model):
def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]:
raise NotImplementedError raise NotImplementedError
def get_paginator[T: User | Group](self, type: type[T]) -> Paginator:
return Paginator(self.get_object_qs(type), PAGE_SIZE)
def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int:
num_pages: int = self.get_paginator(type).num_pages
return int(num_pages * PAGE_TIMEOUT_MS * 1.5)
def get_sync_time_limit_ms(self) -> int:
return int(
(self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group))
* 1.5
)
@property @property
def sync_lock(self) -> pglock.advisory: def sync_lock(self) -> pglock.advisory:
"""Postgres lock for syncing to prevent multiple parallel syncs happening""" """Postgres lock for syncing to prevent multiple parallel syncs happening"""
@ -66,22 +47,3 @@ class OutgoingSyncProvider(ScheduledModel, Model):
timeout=0, timeout=0,
side_effect=pglock.Return, side_effect=pglock.Return,
) )
@property
def sync_actor(self) -> Actor:
raise NotImplementedError
@property
def schedule_specs(self) -> list[ScheduleSpec]:
return [
ScheduleSpec(
actor=self.sync_actor,
uid=self.pk,
args=(self.pk,),
options={
"time_limit": self.get_sync_time_limit_ms(),
},
send_on_save=True,
crontab=f"{fqdn_rand(self.pk)} */4 * * *",
),
]

View File

@ -1,8 +1,12 @@
from collections.abc import Callable
from django.core.paginator import Paginator
from django.db.models import Model from django.db.models import Model
from django.db.models.query import Q
from django.db.models.signals import m2m_changed, post_save, pre_delete from django.db.models.signals import m2m_changed, post_save, pre_delete
from dramatiq.actor import Actor
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
@ -10,30 +14,45 @@ from authentik.lib.utils.reflection import class_to_path
def register_signals( def register_signals(
provider_type: type[OutgoingSyncProvider], provider_type: type[OutgoingSyncProvider],
task_sync_direct_dispatch: Actor[[str, str | int, str], None], task_sync_single: Callable[[int], None],
task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None], task_sync_direct: Callable[[int], None],
task_sync_m2m: Callable[[int], None],
): ):
"""Register sync signals""" """Register sync signals"""
uid = class_to_path(provider_type) uid = class_to_path(provider_type)
def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_):
"""Trigger sync when Provider is saved"""
users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
task_sync_single.apply_async(
(instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False)
def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler""" """Post save handler"""
task_sync_direct_dispatch.send( if not provider_type.objects.filter(
class_to_path(instance.__class__), Q(backchannel_application__isnull=False) | Q(application__isnull=False)
instance.pk, ).exists():
Direction.add.value, return
) task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)
post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False)
post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False)
def model_pre_delete(sender: type[Model], instance: User | Group, **_): def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler""" """Pre-delete handler"""
task_sync_direct_dispatch.send( if not provider_type.objects.filter(
class_to_path(instance.__class__), Q(backchannel_application__isnull=False) | Q(application__isnull=False)
instance.pk, ).exists():
Direction.remove.value, return
) task_sync_direct.delay(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
).get(propagate=False)
pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False)
pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False)
@ -44,6 +63,16 @@ def register_signals(
"""Sync group membership""" """Sync group membership"""
if action not in ["post_add", "post_remove"]: if action not in ["post_add", "post_remove"]:
return return
task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse) if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
if reverse:
task_sync_m2m.delay(str(instance.pk), action, list(pk_set))
else:
for group_pk in pk_set:
task_sync_m2m.delay(group_pk, action, [instance.pk])
m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False)

View File

@ -1,17 +1,23 @@
from collections.abc import Callable
from dataclasses import asdict
from celery import group
from celery.exceptions import Retry
from celery.result import allow_join_result
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.db.models import Model, QuerySet from django.db.models import Model, QuerySet
from django.db.models.query import Q from django.db.models.query import Q
from django.utils.text import slugify from django.utils.text import slugify
from django_dramatiq_postgres.middleware import CurrentTask from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from dramatiq.composition import group
from dramatiq.errors import Retry
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger
from authentik.core.expression.exceptions import SkipObjectException from authentik.core.expression.exceptions import SkipObjectException
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.logs import LogEvent
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask
from authentik.events.utils import sanitize_item from authentik.events.utils import sanitize_item
from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT
from authentik.lib.sync.outgoing.base import Direction from authentik.lib.sync.outgoing.base import Direction
from authentik.lib.sync.outgoing.exceptions import ( from authentik.lib.sync.outgoing.exceptions import (
BadRequestSyncException, BadRequestSyncException,
@ -21,12 +27,11 @@ from authentik.lib.sync.outgoing.exceptions import (
) )
from authentik.lib.sync.outgoing.models import OutgoingSyncProvider from authentik.lib.sync.outgoing.models import OutgoingSyncProvider
from authentik.lib.utils.reflection import class_to_path, path_to_class from authentik.lib.utils.reflection import class_to_path, path_to_class
from authentik.tasks.models import Task
class SyncTasks: class SyncTasks:
"""Container for all sync 'tasks' (this class doesn't actually contain """Container for all sync 'tasks' (this class doesn't actually contain celery
tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)""" tasks due to celery's magic, however exposes a number of functions to be called from tasks)"""
logger: BoundLogger logger: BoundLogger
@ -34,104 +39,107 @@ class SyncTasks:
super().__init__() super().__init__()
self._provider_model = provider_model self._provider_model = provider_model
def sync_paginator( def sync_all(self, single_sync: Callable[[int], None]):
self, for provider in self._provider_model.objects.filter(
current_task: Task, Q(backchannel_application__isnull=False) | Q(application__isnull=False)
provider: OutgoingSyncProvider, ):
sync_objects: Actor[[str, int, int, bool], None], self.trigger_single_task(provider, single_sync)
paginator: Paginator,
object_type: type[User | Group],
**options,
):
tasks = []
for page in paginator.page_range:
page_sync = sync_objects.message_with_options(
args=(class_to_path(object_type), page, provider.pk),
time_limit=PAGE_TIMEOUT_MS,
# Assign tasks to the same schedule as the current one
rel_obj=current_task.rel_obj,
**options,
)
tasks.append(page_sync)
return tasks
def sync( def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
"""Wrapper single sync task that correctly sets time limits based
on the amount of objects that will be synced"""
users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
time_limit = soft_time_limit * 1.5
return sync_task.apply_async(
(provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
)
def sync_single(
self, self,
task: SystemTask,
provider_pk: int, provider_pk: int,
sync_objects: Actor[[str, int, int, bool], None], sync_objects: Callable[[int, int], list[str]],
): ):
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk, provider_pk=provider_pk,
) )
provider: OutgoingSyncProvider = self._provider_model.objects.filter( provider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False), Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk, pk=provider_pk,
).first() ).first()
if not provider: if not provider:
task.warning("No provider found. Is it assigned to an application?")
return return
task.set_uid(slugify(provider.name)) task.set_uid(slugify(provider.name))
task.info("Starting full provider sync") messages = []
messages.append(_("Starting full provider sync"))
self.logger.debug("Starting provider sync") self.logger.debug("Starting provider sync")
with provider.sync_lock as lock_acquired: users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE)
groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE)
with allow_join_result(), provider.sync_lock as lock_acquired:
if not lock_acquired: if not lock_acquired:
task.info("Synchronization is already running. Skipping.")
self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name)
return return
try: try:
users_tasks = group( messages.append(_("Syncing users"))
self.sync_paginator( user_results = (
current_task=task, group(
provider=provider, [
sync_objects=sync_objects, sync_objects.signature(
paginator=provider.get_paginator(User), args=(class_to_path(User), page, provider_pk),
object_type=User, time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
)
for page in users_paginator.page_range
]
) )
.apply_async()
.get()
) )
group_tasks = group( for result in user_results:
self.sync_paginator( for msg in result:
current_task=task, messages.append(LogEvent(**msg))
provider=provider, messages.append(_("Syncing groups"))
sync_objects=sync_objects, group_results = (
paginator=provider.get_paginator(Group), group(
object_type=Group, [
sync_objects.signature(
args=(class_to_path(Group), page, provider_pk),
time_limit=PAGE_TIMEOUT,
soft_time_limit=PAGE_TIMEOUT,
)
for page in groups_paginator.page_range
]
) )
.apply_async()
.get()
) )
users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) for result in group_results:
group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) for msg in result:
messages.append(LogEvent(**msg))
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("transient sync exception", exc=exc) self.logger.warning("transient sync exception", exc=exc)
task.warning("Sync encountered a transient exception. Retrying", exc=exc) raise task.retry(exc=exc) from exc
raise Retry() from exc
except StopSync as exc: except StopSync as exc:
task.error(exc) task.set_error(exc)
return return
task.set_status(TaskStatus.SUCCESSFUL, *messages)
def sync_objects( def sync_objects(
self, self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter
object_type: str,
page: int,
provider_pk: int,
override_dry_run=False,
**filter,
): ):
task: Task = CurrentTask.get_task()
_object_type: type[Model] = path_to_class(object_type) _object_type: type[Model] = path_to_class(object_type)
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
provider_pk=provider_pk, provider_pk=provider_pk,
object_type=object_type, object_type=object_type,
) )
provider: OutgoingSyncProvider = self._provider_model.objects.filter( messages = []
Q(backchannel_application__isnull=False) | Q(application__isnull=False), provider = self._provider_model.objects.filter(pk=provider_pk).first()
pk=provider_pk,
).first()
if not provider: if not provider:
task.warning("No provider found. Is it assigned to an application?") return messages
return
task.set_uid(slugify(provider.name))
# Override dry run mode if requested, however don't save the provider # Override dry run mode if requested, however don't save the provider
# so that scheduled sync tasks still run in dry_run mode # so that scheduled sync tasks still run in dry_run mode
if override_dry_run: if override_dry_run:
@ -139,13 +147,25 @@ class SyncTasks:
try: try:
client = provider.client_for_model(_object_type) client = provider.client_for_model(_object_type)
except TransientSyncException: except TransientSyncException:
return return messages
paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE)
if client.can_discover: if client.can_discover:
self.logger.debug("starting discover") self.logger.debug("starting discover")
client.discover() client.discover()
self.logger.debug("starting sync for page", page=page) self.logger.debug("starting sync for page", page=page)
task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}") messages.append(
asdict(
LogEvent(
_(
"Syncing page {page} of {object_type}".format(
page=page, object_type=_object_type._meta.verbose_name_plural
)
),
log_level="info",
logger=f"{provider._meta.verbose_name}@{object_type}",
)
)
)
for obj in paginator.page(page).object_list: for obj in paginator.page(page).object_list:
obj: Model obj: Model
try: try:
@ -154,58 +174,89 @@ class SyncTasks:
self.logger.debug("skipping object due to SkipObject", obj=obj) self.logger.debug("skipping object due to SkipObject", obj=obj)
continue continue
except DryRunRejected as exc: except DryRunRejected as exc:
task.info( messages.append(
"Dropping mutating request due to dry run", asdict(
obj=sanitize_item(obj), LogEvent(
method=exc.method, _("Dropping mutating request due to dry run"),
url=exc.url, log_level="info",
body=exc.body, logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={
"obj": sanitize_item(obj),
"method": exc.method,
"url": exc.url,
"body": exc.body,
},
)
)
) )
except BadRequestSyncException as exc: except BadRequestSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, obj=obj) self.logger.warning("failed to sync object", exc=exc, obj=obj)
task.warning( messages.append(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", asdict(
arguments=exc.args[1:], LogEvent(
obj=sanitize_item(obj), _(
(
"Failed to sync {object_type} {object_name} "
"due to error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)},
)
)
) )
except TransientSyncException as exc: except TransientSyncException as exc:
self.logger.warning("failed to sync object", exc=exc, user=obj) self.logger.warning("failed to sync object", exc=exc, user=obj)
task.warning( messages.append(
f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " asdict(
"transient error: {str(exc)}", LogEvent(
obj=sanitize_item(obj), _(
(
"Failed to sync {object_type} {object_name} "
"due to transient error: {error}"
).format_map(
{
"object_type": obj._meta.verbose_name,
"object_name": str(obj),
"error": str(exc),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
) )
except StopSync as exc: except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc) self.logger.warning("Stopping sync", exc=exc)
task.warning( messages.append(
f"Stopping sync due to error: {exc.detail()}", asdict(
obj=sanitize_item(obj), LogEvent(
_(
"Stopping sync due to error: {error}".format_map(
{
"error": exc.detail(),
}
)
),
log_level="warning",
logger=f"{provider._meta.verbose_name}@{object_type}",
attributes={"obj": sanitize_item(obj)},
)
)
) )
break break
return messages
def sync_signal_direct_dispatch( def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
self,
task_sync_signal_direct: Actor[[str, str | int, int, str], None],
model: str,
pk: str | int,
raw_op: str,
):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
task_sync_signal_direct.send_with_options(
args=(model, pk, provider.pk, raw_op),
rel_obj=provider,
)
def sync_signal_direct(
self,
model: str,
pk: str | int,
provider_pk: int,
raw_op: str,
):
task: Task = CurrentTask.get_task()
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )
@ -213,108 +264,65 @@ class SyncTasks:
instance = model_class.objects.filter(pk=pk).first() instance = model_class.objects.filter(pk=pk).first()
if not instance: if not instance:
return return
provider: OutgoingSyncProvider = self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
task.warning("No provider found. Is it assigned to an application?")
return
task.set_uid(slugify(provider.name))
operation = Direction(raw_op) operation = Direction(raw_op)
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
if not queryset:
return
# The queryset we get from the provider must include the instance we've got given
# otherwise ignore this provider
if not queryset.filter(pk=instance.pk).exists():
return
try:
if operation == Direction.add:
client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
return
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m_dispatch(
self,
task_sync_signal_m2m: Actor[[str, int, str, list[int]], None],
instance_pk: str,
action: str,
pk_set: list[int],
reverse: bool,
):
for provider in self._provider_model.objects.filter( for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False) Q(backchannel_application__isnull=False) | Q(application__isnull=False)
): ):
# reverse: instance is a Group, pk_set is a list of user pks client = provider.client_for_model(instance.__class__)
# non-reverse: instance is a User, pk_set is a list of groups # Check if the object is allowed within the provider's restrictions
if reverse: queryset = provider.get_object_qs(instance.__class__)
task_sync_signal_m2m.send_with_options( if not queryset:
args=(instance_pk, provider.pk, action, list(pk_set)), continue
rel_obj=provider,
)
else:
for pk in pk_set:
task_sync_signal_m2m.send_with_options(
args=(pk, provider.pk, action, [instance_pk]),
rel_obj=provider,
)
def sync_signal_m2m( # The queryset we get from the provider must include the instance we've got given
self, # otherwise ignore this provider
group_pk: str, if not queryset.filter(pk=instance.pk).exists():
provider_pk: int, continue
action: str,
pk_set: list[int], try:
): if operation == Direction.add:
task: Task = CurrentTask.get_task() client.write(instance)
if operation == Direction.remove:
client.delete(instance)
except TransientSyncException as exc:
raise Retry() from exc
except SkipObjectException:
continue
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
self.logger = get_logger().bind( self.logger = get_logger().bind(
provider_type=class_to_path(self._provider_model), provider_type=class_to_path(self._provider_model),
) )
group = Group.objects.filter(pk=group_pk).first() group = Group.objects.filter(pk=group_pk).first()
if not group: if not group:
return return
provider: OutgoingSyncProvider = self._provider_model.objects.filter( for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False), Q(backchannel_application__isnull=False) | Q(application__isnull=False)
pk=provider_pk, ):
).first() # Check if the object is allowed within the provider's restrictions
if not provider: queryset: QuerySet = provider.get_object_qs(Group)
task.warning("No provider found. Is it assigned to an application?") # The queryset we get from the provider must include the instance we've got given
return # otherwise ignore this provider
task.set_uid(slugify(provider.name)) if not queryset.filter(pk=group_pk).exists():
continue
# Check if the object is allowed within the provider's restrictions client = provider.client_for_model(Group)
queryset: QuerySet = provider.get_object_qs(Group) try:
# The queryset we get from the provider must include the instance we've got given operation = None
# otherwise ignore this provider if action == "post_add":
if not queryset.filter(pk=group_pk).exists(): operation = Direction.add
return if action == "post_remove":
operation = Direction.remove
client = provider.client_for_model(Group) client.update_group(group, operation, pk_set)
try: except TransientSyncException as exc:
operation = None raise Retry() from exc
if action == "post_add": except SkipObjectException:
operation = Direction.add continue
if action == "post_remove": except DryRunRejected as exc:
operation = Direction.remove self.logger.info("Rejected dry-run event", exc=exc)
client.update_group(group, operation, pk_set) except StopSync as exc:
except TransientSyncException as exc: self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)
raise Retry() from exc
except SkipObjectException:
return
except DryRunRejected as exc:
self.logger.info("Rejected dry-run event", exc=exc)
except StopSync as exc:
self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk)

View File

@ -5,8 +5,6 @@ from structlog.stdlib import get_logger
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec
LOGGER = get_logger() LOGGER = get_logger()
@ -62,27 +60,3 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.save() outpost.save()
else: else:
Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() Outpost.objects.filter(managed=MANAGED_OUTPOST).delete()
@property
def tenant_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.outposts.tasks import outpost_token_ensurer
return [
ScheduleSpec(
actor=outpost_token_ensurer,
crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *",
),
]
@property
def global_schedule_specs(self) -> list[ScheduleSpec]:
from authentik.outposts.tasks import outpost_connection_discovery
return [
ScheduleSpec(
actor=outpost_connection_discovery,
crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *",
send_on_startup=True,
paused=not CONFIG.get_bool("outposts.discover"),
),
]

View File

@ -101,13 +101,7 @@ class KubernetesController(BaseController):
all_logs = [] all_logs = []
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:
if reconcile_key in self.outpost.config.kubernetes_disabled_components: if reconcile_key in self.outpost.config.kubernetes_disabled_components:
all_logs.append( all_logs += [f"{reconcile_key.title()}: Disabled"]
LogEvent(
log_level="info",
event=f"{reconcile_key.title()}: Disabled",
logger=str(type(self)),
)
)
continue continue
with capture_logs() as logs: with capture_logs() as logs:
reconciler_cls = self.reconcilers.get(reconcile_key) reconciler_cls = self.reconcilers.get(reconcile_key)
@ -140,13 +134,7 @@ class KubernetesController(BaseController):
all_logs = [] all_logs = []
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:
if reconcile_key in self.outpost.config.kubernetes_disabled_components: if reconcile_key in self.outpost.config.kubernetes_disabled_components:
all_logs.append( all_logs += [f"{reconcile_key.title()}: Disabled"]
LogEvent(
log_level="info",
event=f"{reconcile_key.title()}: Disabled",
logger=str(type(self)),
)
)
continue continue
with capture_logs() as logs: with capture_logs() as logs:
reconciler_cls = self.reconcilers.get(reconcile_key) reconciler_cls = self.reconcilers.get(reconcile_key)

View File

@ -36,10 +36,7 @@ from authentik.lib.config import CONFIG
from authentik.lib.models import InheritanceForeignKey, SerializerModel from authentik.lib.models import InheritanceForeignKey, SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.time import fqdn_rand
from authentik.outposts.controllers.k8s.utils import get_namespace from authentik.outposts.controllers.k8s.utils import get_namespace
from authentik.tasks.schedules.lib import ScheduleSpec
from authentik.tasks.schedules.models import ScheduledModel
OUR_VERSION = parse(__version__) OUR_VERSION = parse(__version__)
OUTPOST_HELLO_INTERVAL = 10 OUTPOST_HELLO_INTERVAL = 10
@ -118,7 +115,7 @@ class OutpostServiceConnectionState:
healthy: bool healthy: bool
class OutpostServiceConnection(ScheduledModel, models.Model): class OutpostServiceConnection(models.Model):
"""Connection details for an Outpost Controller, like Docker or Kubernetes""" """Connection details for an Outpost Controller, like Docker or Kubernetes"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
@ -148,11 +145,11 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
@property @property
def state(self) -> OutpostServiceConnectionState: def state(self) -> OutpostServiceConnectionState:
"""Get state of service connection""" """Get state of service connection"""
from authentik.outposts.tasks import outpost_service_connection_monitor from authentik.outposts.tasks import outpost_service_connection_state
state = cache.get(self.state_key, None) state = cache.get(self.state_key, None)
if not state: if not state:
outpost_service_connection_monitor.send_with_options(args=(self.pk), rel_obj=self) outpost_service_connection_state.delay(self.pk)
return OutpostServiceConnectionState("", False) return OutpostServiceConnectionState("", False)
return state return state
@ -163,20 +160,6 @@ class OutpostServiceConnection(ScheduledModel, models.Model):
# since the response doesn't use the correct inheritance # since the response doesn't use the correct inheritance
return "" return ""
@property
def schedule_specs(self) -> list[ScheduleSpec]:
from authentik.outposts.tasks import outpost_service_connection_monitor
return [
ScheduleSpec(
actor=outpost_service_connection_monitor,
uid=self.pk,
args=(self.pk,),
crontab="3-59/15 * * * *",
send_on_save=True,
),
]
class DockerServiceConnection(SerializerModel, OutpostServiceConnection): class DockerServiceConnection(SerializerModel, OutpostServiceConnection):
"""Service Connection to a Docker endpoint""" """Service Connection to a Docker endpoint"""
@ -261,7 +244,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection):
return "ak-service-connection-kubernetes-form" return "ak-service-connection-kubernetes-form"
class Outpost(ScheduledModel, SerializerModel, ManagedModel): class Outpost(SerializerModel, ManagedModel):
"""Outpost instance which manages a service user and token""" """Outpost instance which manages a service user and token"""
uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True)
@ -315,21 +298,6 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
"""Username for service user""" """Username for service user"""
return f"ak-outpost-{self.uuid.hex}" return f"ak-outpost-{self.uuid.hex}"
@property
def schedule_specs(self) -> list[ScheduleSpec]:
from authentik.outposts.tasks import outpost_controller
return [
ScheduleSpec(
actor=outpost_controller,
uid=self.pk,
args=(self.pk,),
kwargs={"action": "up", "from_cache": False},
crontab=f"{fqdn_rand('outpost_controller')} */4 * * *",
send_on_save=True,
),
]
def build_user_permissions(self, user: User): def build_user_permissions(self, user: User):
"""Create per-object and global permissions for outpost service-account""" """Create per-object and global permissions for outpost service-account"""
# To ensure the user only has the correct permissions, we delete all of them and re-add # To ensure the user only has the correct permissions, we delete all of them and re-add

View File

@ -0,0 +1,28 @@
"""Outposts Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"outposts_controller": {
"task": "authentik.outposts.tasks.outpost_controller_all",
"schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
"outposts_service_connection_check": {
"task": "authentik.outposts.tasks.outpost_service_connection_monitor",
"schedule": crontab(minute="3-59/15"),
"options": {"queue": "authentik_scheduled"},
},
"outpost_token_ensurer": {
"task": "authentik.outposts.tasks.outpost_token_ensurer",
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
},
"outpost_connection_discovery": {
"task": "authentik.outposts.tasks.outpost_connection_discovery",
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -1,6 +1,7 @@
"""authentik outpost signals""" """authentik outpost signals"""
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -8,19 +9,27 @@ from structlog.stdlib import get_logger
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import AuthenticatedSession, Provider from authentik.core.models import AuthenticatedSession, Provider
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
from authentik.outposts.tasks import ( from authentik.outposts.tasks import (
CACHE_KEY_OUTPOST_DOWN, CACHE_KEY_OUTPOST_DOWN,
outpost_controller, outpost_controller,
outpost_send_update, outpost_post_save,
outpost_session_end, outpost_session_end,
) )
LOGGER = get_logger() LOGGER = get_logger()
UPDATE_TRIGGERING_MODELS = (
Outpost,
OutpostServiceConnection,
Provider,
CertificateKeyPair,
Brand,
)
@receiver(pre_save, sender=Outpost) @receiver(pre_save, sender=Outpost)
def outpost_pre_save(sender, instance: Outpost, **_): def pre_save_outpost(sender, instance: Outpost, **_):
"""Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes,
we call down and then wait for the up after save""" we call down and then wait for the up after save"""
old_instances = Outpost.objects.filter(pk=instance.pk) old_instances = Outpost.objects.filter(pk=instance.pk)
@ -35,89 +44,43 @@ def outpost_pre_save(sender, instance: Outpost, **_):
if bool(dirty): if bool(dirty):
LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) LOGGER.info("Outpost needs re-deployment due to changes", instance=instance)
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance)
outpost_controller.send_with_options( outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
args=(instance.pk.hex,),
kwargs={"action": "down", "from_cache": True},
rel_obj=instance,
)
@receiver(m2m_changed, sender=Outpost.providers.through) @receiver(m2m_changed, sender=Outpost.providers.through)
def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_): def m2m_changed_update(sender, instance: Model, action: str, **_):
"""Update outpost on m2m change, when providers are added or removed""" """Update outpost on m2m change, when providers are added or removed"""
if action not in ["post_add", "post_remove", "post_clear"]: if action in ["post_add", "post_remove", "post_clear"]:
outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
@receiver(post_save)
def post_save_update(sender, instance: Model, created: bool, **_):
"""If an Outpost is saved, Ensure that token is created/updated
If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
we send a message down the relevant OutpostModels WS connection to trigger an update"""
if instance.__module__ == "django.db.migrations.recorder":
return return
if isinstance(instance, Outpost): if instance.__module__ == "__fake__":
outpost_controller.send_with_options( return
args=(instance.pk,), if not isinstance(instance, UPDATE_TRIGGERING_MODELS):
rel_obj=instance.service_connection, return
) if isinstance(instance, Outpost) and created:
outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
elif isinstance(instance, OutpostModel):
for outpost in instance.outpost_set.all():
outpost_controller.send_with_options(
args=(instance.pk,),
rel_obj=instance.service_connection,
)
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
@receiver(post_save, sender=Outpost)
def outpost_post_save(sender, instance: Outpost, created: bool, **_):
if created:
LOGGER.info("New outpost saved, ensuring initial token and user are created") LOGGER.info("New outpost saved, ensuring initial token and user are created")
_ = instance.token _ = instance.token
outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection) outpost_post_save.delay(class_to_path(instance.__class__), instance.pk)
outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance)
def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_):
for outpost in instance.outpost_set.all():
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False)
for subclass in OutpostModel.__subclasses__():
post_save.connect(outpost_related_post_save, sender=subclass, weak=False)
def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_):
for field in instance._meta.get_fields():
# Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
# are used, and if it has a value
if not hasattr(field, "related_model"):
continue
if not field.related_model:
continue
if not issubclass(field.related_model, OutpostModel):
continue
field_name = f"{field.name}_set"
if not hasattr(instance, field_name):
continue
LOGGER.debug("triggering outpost update from field", field=field.name)
# Because the Outpost Model has an M2M to Provider,
# we have to iterate over the entire QS
for reverse in getattr(instance, field_name).all():
if isinstance(reverse, OutpostModel):
for outpost in reverse.outpost_set.all():
outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost)
post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False)
post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False)
@receiver(pre_delete, sender=Outpost) @receiver(pre_delete, sender=Outpost)
def outpost_pre_delete_cleanup(sender, instance: Outpost, **_): def pre_delete_cleanup(sender, instance: Outpost, **_):
"""Ensure that Outpost's user is deleted (which will delete the token through cascade)""" """Ensure that Outpost's user is deleted (which will delete the token through cascade)"""
instance.user.delete() instance.user.delete()
cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance)
outpost_controller.send(instance.pk.hex, action="down", from_cache=True) outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(pre_delete, sender=AuthenticatedSession) @receiver(pre_delete, sender=AuthenticatedSession)
def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted""" """Catch logout by expiring sessions being deleted"""
outpost_session_end.send(instance.session.session_key) outpost_session_end.delay(instance.session.session_key)

View File

@ -10,17 +10,19 @@ from urllib.parse import urlparse
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.core.cache import cache from django.core.cache import cache
from django.db import DatabaseError, InternalError, ProgrammingError
from django.db.models.base import Model
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from docker.constants import DEFAULT_UNIX_SOCKET from docker.constants import DEFAULT_UNIX_SOCKET
from dramatiq.actor import actor
from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME
from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from yaml import safe_load from yaml import safe_load
from authentik.events.models import TaskStatus
from authentik.events.system_tasks import SystemTask, prefill_task
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.utils.reflection import path_to_class
from authentik.outposts.consumer import OUTPOST_GROUP from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.controllers.base import BaseController, ControllerException from authentik.outposts.controllers.base import BaseController, ControllerException
from authentik.outposts.controllers.docker import DockerClient from authentik.outposts.controllers.docker import DockerClient
@ -29,6 +31,7 @@ from authentik.outposts.models import (
DockerServiceConnection, DockerServiceConnection,
KubernetesServiceConnection, KubernetesServiceConnection,
Outpost, Outpost,
OutpostModel,
OutpostServiceConnection, OutpostServiceConnection,
OutpostType, OutpostType,
ServiceConnectionInvalid, ServiceConnectionInvalid,
@ -41,7 +44,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController
from authentik.providers.rac.controllers.kubernetes import RACKubernetesController from authentik.providers.rac.controllers.kubernetes import RACKubernetesController
from authentik.providers.radius.controllers.docker import RadiusDockerController from authentik.providers.radius.controllers.docker import RadiusDockerController
from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController
from authentik.tasks.models import Task from authentik.root.celery import CELERY_APP
LOGGER = get_logger() LOGGER = get_logger()
CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s"
@ -80,8 +83,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None:
return None return None
@actor(description=_("Update cached state of service connection.")) @CELERY_APP.task()
def outpost_service_connection_monitor(connection_pk: Any): def outpost_service_connection_state(connection_pk: Any):
"""Update cached state of a service connection""" """Update cached state of a service connection"""
connection: OutpostServiceConnection = ( connection: OutpostServiceConnection = (
OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first()
@ -105,11 +108,37 @@ def outpost_service_connection_monitor(connection_pk: Any):
cache.set(connection.state_key, state, timeout=None) cache.set(connection.state_key, state, timeout=None)
@actor(description=_("Create/update/monitor/delete the deployment of an Outpost.")) @CELERY_APP.task(
def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): bind=True,
base=SystemTask,
throws=(DatabaseError, ProgrammingError, InternalError),
)
@prefill_task
def outpost_service_connection_monitor(self: SystemTask):
"""Regularly check the state of Outpost Service Connections"""
connections = OutpostServiceConnection.objects.all()
for connection in connections.iterator():
outpost_service_connection_state.delay(connection.pk)
self.set_status(
TaskStatus.SUCCESSFUL,
f"Successfully updated {len(connections)} connections.",
)
@CELERY_APP.task(
throws=(DatabaseError, ProgrammingError, InternalError),
)
def outpost_controller_all():
"""Launch Controller for all Outposts which support it"""
for outpost in Outpost.objects.exclude(service_connection=None):
outpost_controller.delay(outpost.pk.hex, "up", from_cache=False)
@CELERY_APP.task(bind=True, base=SystemTask)
def outpost_controller(
self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False
):
"""Create/update/monitor/delete the deployment of an Outpost""" """Create/update/monitor/delete the deployment of an Outpost"""
self: Task = CurrentTask.get_task()
self.set_uid(outpost_pk)
logs = [] logs = []
if from_cache: if from_cache:
outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
@ -130,65 +159,125 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F
logs = getattr(controller, f"{action}_with_logs")() logs = getattr(controller, f"{action}_with_logs")()
LOGGER.debug("-----------------Outpost Controller logs end-------------------") LOGGER.debug("-----------------Outpost Controller logs end-------------------")
except (ControllerException, ServiceConnectionInvalid) as exc: except (ControllerException, ServiceConnectionInvalid) as exc:
self.error(exc) self.set_error(exc)
else: else:
if from_cache: if from_cache:
cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk)
self.logs(logs) self.set_status(TaskStatus.SUCCESSFUL, *logs)
@actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens.")) @CELERY_APP.task(bind=True, base=SystemTask)
def outpost_token_ensurer(): @prefill_task
""" def outpost_token_ensurer(self: SystemTask):
Periodically ensure that all Outposts have valid Service Accounts and Tokens """Periodically ensure that all Outposts have valid Service Accounts
""" and Tokens"""
self: Task = CurrentTask.get_task()
all_outposts = Outpost.objects.all() all_outposts = Outpost.objects.all()
for outpost in all_outposts: for outpost in all_outposts:
_ = outpost.token _ = outpost.token
outpost.build_user_permissions(outpost.user) outpost.build_user_permissions(outpost.user)
self.info(f"Successfully checked {len(all_outposts)} Outposts.") self.set_status(
TaskStatus.SUCCESSFUL,
f"Successfully checked {len(all_outposts)} Outposts.",
)
@actor(description=_("Send update to outpost")) @CELERY_APP.task()
def outpost_send_update(pk: Any): def outpost_post_save(model_class: str, model_pk: Any):
"""Update outpost instance""" """If an Outpost is saved, Ensure that token is created/updated
outpost = Outpost.objects.filter(pk=pk).first()
if not outpost: If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved,
we send a message down the relevant OutpostModels WS connection to trigger an update"""
model: Model = path_to_class(model_class)
try:
instance = model.objects.get(pk=model_pk)
except model.DoesNotExist:
LOGGER.warning("Model does not exist", model=model, pk=model_pk)
return return
if isinstance(instance, Outpost):
LOGGER.debug("Trigger reconcile for outpost", instance=instance)
outpost_controller.delay(str(instance.pk))
if isinstance(instance, OutpostModel | Outpost):
LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance)
outpost_send_update(instance)
if isinstance(instance, OutpostServiceConnection):
LOGGER.debug("triggering ServiceConnection state update", instance=instance)
outpost_service_connection_state.delay(str(instance.pk))
for field in instance._meta.get_fields():
# Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms)
# are used, and if it has a value
if not hasattr(field, "related_model"):
continue
if not field.related_model:
continue
if not issubclass(field.related_model, OutpostModel):
continue
field_name = f"{field.name}_set"
if not hasattr(instance, field_name):
continue
LOGGER.debug("triggering outpost update from field", field=field.name)
# Because the Outpost Model has an M2M to Provider,
# we have to iterate over the entire QS
for reverse in getattr(instance, field_name).all():
outpost_send_update(reverse)
def outpost_send_update(model_instance: Model):
"""Send outpost update to all registered outposts, regardless to which authentik
instance they are connected"""
channel_layer = get_channel_layer()
if isinstance(model_instance, OutpostModel):
for outpost in model_instance.outpost_set.all():
_outpost_single_update(outpost, channel_layer)
elif isinstance(model_instance, Outpost):
_outpost_single_update(model_instance, channel_layer)
def _outpost_single_update(outpost: Outpost, layer=None):
"""Update outpost instances connected to a single outpost"""
# Ensure token again, because this function is called when anything related to an # Ensure token again, because this function is called when anything related to an
# OutpostModel is saved, so we can be sure permissions are right # OutpostModel is saved, so we can be sure permissions are right
_ = outpost.token _ = outpost.token
outpost.build_user_permissions(outpost.user) outpost.build_user_permissions(outpost.user)
layer = get_channel_layer() if not layer: # pragma: no cover
layer = get_channel_layer()
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
LOGGER.debug("sending update", channel=group, outpost=outpost) LOGGER.debug("sending update", channel=group, outpost=outpost)
async_to_sync(layer.group_send)(group, {"type": "event.update"}) async_to_sync(layer.group_send)(group, {"type": "event.update"})
@actor(description=_("Checks the local environment and create Service connections.")) @CELERY_APP.task(
def outpost_connection_discovery(): base=SystemTask,
bind=True,
)
def outpost_connection_discovery(self: SystemTask):
"""Checks the local environment and create Service connections.""" """Checks the local environment and create Service connections."""
self: Task = CurrentTask.get_task() messages = []
if not CONFIG.get_bool("outposts.discover"): if not CONFIG.get_bool("outposts.discover"):
self.info("Outpost integration discovery is disabled") messages.append("Outpost integration discovery is disabled")
self.set_status(TaskStatus.SUCCESSFUL, *messages)
return return
# Explicitly check against token filename, as that's # Explicitly check against token filename, as that's
# only present when the integration is enabled # only present when the integration is enabled
if Path(SERVICE_TOKEN_FILENAME).exists(): if Path(SERVICE_TOKEN_FILENAME).exists():
self.info("Detected in-cluster Kubernetes Config") messages.append("Detected in-cluster Kubernetes Config")
if not KubernetesServiceConnection.objects.filter(local=True).exists(): if not KubernetesServiceConnection.objects.filter(local=True).exists():
self.info("Created Service Connection for in-cluster") messages.append("Created Service Connection for in-cluster")
KubernetesServiceConnection.objects.create( KubernetesServiceConnection.objects.create(
name="Local Kubernetes Cluster", local=True, kubeconfig={} name="Local Kubernetes Cluster", local=True, kubeconfig={}
) )
# For development, check for the existence of a kubeconfig file # For development, check for the existence of a kubeconfig file
kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser()
if kubeconfig_path.exists(): if kubeconfig_path.exists():
self.info("Detected kubeconfig") messages.append("Detected kubeconfig")
kubeconfig_local_name = f"k8s-{gethostname()}" kubeconfig_local_name = f"k8s-{gethostname()}"
if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists():
self.info("Creating kubeconfig Service Connection") messages.append("Creating kubeconfig Service Connection")
with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig:
KubernetesServiceConnection.objects.create( KubernetesServiceConnection.objects.create(
name=kubeconfig_local_name, name=kubeconfig_local_name,
@ -197,18 +286,20 @@ def outpost_connection_discovery():
unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path
socket = Path(unix_socket_path) socket = Path(unix_socket_path)
if socket.exists() and access(socket, R_OK): if socket.exists() and access(socket, R_OK):
self.info("Detected local docker socket") messages.append("Detected local docker socket")
if len(DockerServiceConnection.objects.filter(local=True)) == 0: if len(DockerServiceConnection.objects.filter(local=True)) == 0:
self.info("Created Service Connection for docker") messages.append("Created Service Connection for docker")
DockerServiceConnection.objects.create( DockerServiceConnection.objects.create(
name="Local Docker connection", name="Local Docker connection",
local=True, local=True,
url=unix_socket_path, url=unix_socket_path,
) )
self.set_status(TaskStatus.SUCCESSFUL, *messages)
@actor(description=_("Terminate session on all outposts.")) @CELERY_APP.task()
def outpost_session_end(session_id: str): def outpost_session_end(session_id: str):
"""Update outpost instances connected to a single outpost"""
layer = get_channel_layer() layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id) hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.all(): for outpost in Outpost.objects.all():

View File

@ -37,7 +37,6 @@ class OutpostTests(TestCase):
# We add a provider, user should only have access to outpost and provider # We add a provider, user should only have access to outpost and provider
outpost.providers.add(provider) outpost.providers.add(provider)
provider.refresh_from_db()
permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by(
"content_type__model" "content_type__model"
) )

View File

@ -15,7 +15,6 @@ class AuthentikProviderProxyConfig(ManagedAppConfig):
def proxy_set_defaults(self): def proxy_set_defaults(self):
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
# TODO: figure out if this can be in pre_save + post_save signals
for provider in ProxyProvider.objects.all(): for provider in ProxyProvider.objects.all():
provider.set_oauth_defaults() provider.set_oauth_defaults()
provider.save() provider.save()

View File

@ -1,13 +0,0 @@
"""Proxy provider signals"""
from django.db.models.signals import pre_delete
from django.dispatch import receiver
from authentik.core.models import AuthenticatedSession
from authentik.providers.proxy.tasks import proxy_on_logout
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""
proxy_on_logout.send(instance.session.session_key)

View File

@ -1,26 +0,0 @@
"""proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
from authentik.outposts.consumer import OUTPOST_GROUP
from authentik.outposts.models import Outpost, OutpostType
from authentik.providers.oauth2.id_token import hash_session_key
@actor(description=_("Terminate session on Proxy outpost."))
def proxy_on_logout(session_id: str):
layer = get_channel_layer()
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)(
group,
{
"type": "event.provider.specific",
"sub_type": "logout",
"session_id": hashed_session_id,
},
)

View File

@ -17,7 +17,6 @@ from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.lib.utils.time import timedelta_string_validator from authentik.lib.utils.time import timedelta_string_validator
from authentik.outposts.models import OutpostModel
from authentik.policies.models import PolicyBindingModel from authentik.policies.models import PolicyBindingModel
LOGGER = get_logger() LOGGER = get_logger()
@ -38,7 +37,7 @@ class AuthenticationMode(models.TextChoices):
PROMPT = "prompt" PROMPT = "prompt"
class RACProvider(OutpostModel, Provider): class RACProvider(Provider):
"""Remotely access computers/servers via RDP/SSH/VNC.""" """Remotely access computers/servers via RDP/SSH/VNC."""
settings = models.JSONField(default=dict) settings = models.JSONField(default=dict)

View File

@ -44,5 +44,5 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie
filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"] filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"]
search_fields = ["name", "url"] search_fields = ["name", "url"]
ordering = ["name", "url"] ordering = ["name", "url"]
sync_task = scim_sync sync_single_task = scim_sync
sync_objects_task = scim_sync_objects sync_objects_task = scim_sync_objects

View File

@ -3,6 +3,7 @@
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.tenants.management import TenantCommand from authentik.tenants.management import TenantCommand
LOGGER = get_logger() LOGGER = get_logger()
@ -20,5 +21,4 @@ class Command(TenantCommand):
if not provider: if not provider:
LOGGER.warning("Provider does not exist", name=provider_name) LOGGER.warning("Provider does not exist", name=provider_name)
continue continue
for schedule in provider.schedules.all(): sync_tasks.trigger_single_task(provider, scim_sync).get()
schedule.send().get_result()

View File

@ -7,7 +7,6 @@ from django.db import models
from django.db.models import QuerySet from django.db.models import QuerySet
from django.templatetags.static import static from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes
@ -100,12 +99,6 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
def icon_url(self) -> str | None: def icon_url(self) -> str | None:
return static("authentik/sources/scim.png") return static("authentik/sources/scim.png")
@property
def sync_actor(self) -> Actor:
from authentik.providers.scim.tasks import scim_sync
return scim_sync
def client_for_model( def client_for_model(
self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup]
) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]:

View File

@ -0,0 +1,13 @@
"""SCIM task Settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"providers_scim_sync": {
"task": "authentik.providers.scim.tasks.scim_sync_all",
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"),
"options": {"queue": "authentik_scheduled"},
},
}

View File

@ -2,10 +2,11 @@
from authentik.lib.sync.outgoing.signals import register_signals from authentik.lib.sync.outgoing.signals import register_signals
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m
register_signals( register_signals(
SCIMProvider, SCIMProvider,
task_sync_direct_dispatch=scim_sync_direct_dispatch, task_sync_single=scim_sync,
task_sync_m2m_dispatch=scim_sync_m2m_dispatch, task_sync_direct=scim_sync_direct,
task_sync_m2m=scim_sync_m2m,
) )

View File

@ -1,40 +1,37 @@
"""SCIM Provider tasks""" """SCIM Provider tasks"""
from django.utils.translation import gettext_lazy as _ from authentik.events.system_tasks import SystemTask
from dramatiq.actor import actor from authentik.lib.sync.outgoing.exceptions import TransientSyncException
from authentik.lib.sync.outgoing.tasks import SyncTasks from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.providers.scim.models import SCIMProvider from authentik.providers.scim.models import SCIMProvider
from authentik.root.celery import CELERY_APP
sync_tasks = SyncTasks(SCIMProvider) sync_tasks = SyncTasks(SCIMProvider)
@actor(description=_("Sync SCIM provider objects.")) @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def scim_sync_objects(*args, **kwargs): def scim_sync_objects(*args, **kwargs):
return sync_tasks.sync_objects(*args, **kwargs) return sync_tasks.sync_objects(*args, **kwargs)
@actor(description=_("Full sync for SCIM provider.")) @CELERY_APP.task(
def scim_sync(provider_pk: int, *args, **kwargs): base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True
)
def scim_sync(self, provider_pk: int, *args, **kwargs):
"""Run full sync for SCIM provider""" """Run full sync for SCIM provider"""
return sync_tasks.sync(provider_pk, scim_sync_objects) return sync_tasks.sync_single(self, provider_pk, scim_sync_objects)
@actor(description=_("Sync a direct object (user, group) for SCIM provider.")) @CELERY_APP.task()
def scim_sync_all():
return sync_tasks.sync_all(scim_sync)
@CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def scim_sync_direct(*args, **kwargs): def scim_sync_direct(*args, **kwargs):
return sync_tasks.sync_signal_direct(*args, **kwargs) return sync_tasks.sync_signal_direct(*args, **kwargs)
@actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers.")) @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
def scim_sync_direct_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs)
@actor(description=_("Sync a related object (memberships) for SCIM provider."))
def scim_sync_m2m(*args, **kwargs): def scim_sync_m2m(*args, **kwargs):
return sync_tasks.sync_signal_m2m(*args, **kwargs) return sync_tasks.sync_signal_m2m(*args, **kwargs)
@actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers."))
def scim_sync_m2m_dispatch(*args, **kwargs):
return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs)

View File

@ -8,7 +8,7 @@ from authentik.core.models import Application
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.base import SCIMClient from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync_all
class SCIMClientTests(TestCase): class SCIMClientTests(TestCase):
@ -85,6 +85,6 @@ class SCIMClientTests(TestCase):
self.assertEqual(mock.call_count, 1) self.assertEqual(mock.call_count, 1)
self.assertEqual(mock.request_history[0].method, "GET") self.assertEqual(mock.request_history[0].method, "GET")
def test_scim_sync(self): def test_scim_sync_all(self):
"""test scim_sync task""" """test scim_sync_all task"""
scim_sync.send(self.provider.pk).get_result() scim_sync_all()

View File

@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -79,15 +79,17 @@ class SCIMMembershipTests(TestCase):
) )
self.configure() self.configure()
scim_sync.send(self.provider.pk) sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET") self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST") self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[3].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [], "emails": [],
@ -99,7 +101,7 @@ class SCIMMembershipTests(TestCase):
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[3].body, mocker.request_history[5].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk), "externalId": str(group.pk),
@ -167,15 +169,17 @@ class SCIMMembershipTests(TestCase):
) )
self.configure() self.configure()
scim_sync.send(self.provider.pk) sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET") self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST") self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[3].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
@ -187,7 +191,7 @@ class SCIMMembershipTests(TestCase):
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[3].body, mocker.request_history[5].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk), "externalId": str(group.pk),
@ -283,15 +287,17 @@ class SCIMMembershipTests(TestCase):
) )
self.configure() self.configure()
scim_sync.send(self.provider.pk) sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mocker.call_count, 4) self.assertEqual(mocker.call_count, 6)
self.assertEqual(mocker.request_history[0].method, "GET") self.assertEqual(mocker.request_history[0].method, "GET")
self.assertEqual(mocker.request_history[1].method, "POST") self.assertEqual(mocker.request_history[1].method, "GET")
self.assertEqual(mocker.request_history[2].method, "GET") self.assertEqual(mocker.request_history[2].method, "GET")
self.assertEqual(mocker.request_history[3].method, "POST") self.assertEqual(mocker.request_history[3].method, "POST")
self.assertEqual(mocker.request_history[4].method, "GET")
self.assertEqual(mocker.request_history[5].method, "POST")
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[3].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [], "emails": [],
@ -303,7 +309,7 @@ class SCIMMembershipTests(TestCase):
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[3].body, mocker.request_history[5].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk), "externalId": str(group.pk),

View File

@ -9,11 +9,11 @@ from requests_mock import Mocker
from authentik.blueprints.tests import apply_blueprint from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User from authentik.core.models import Application, Group, User
from authentik.events.models import SystemTask
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.sync.outgoing.base import SAFE_METHODS from authentik.lib.sync.outgoing.base import SAFE_METHODS
from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.providers.scim.models import SCIMMapping, SCIMProvider
from authentik.providers.scim.tasks import scim_sync, scim_sync_objects from authentik.providers.scim.tasks import scim_sync, sync_tasks
from authentik.tasks.models import Task
from authentik.tenants.models import Tenant from authentik.tenants.models import Tenant
@ -356,7 +356,7 @@ class SCIMUserTests(TestCase):
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
scim_sync.send(self.provider.pk) sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mock.call_count, 5) self.assertEqual(mock.call_count, 5)
self.assertEqual(mock.request_history[0].method, "GET") self.assertEqual(mock.request_history[0].method, "GET")
@ -428,19 +428,14 @@ class SCIMUserTests(TestCase):
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
scim_sync.send(self.provider.pk) sync_tasks.trigger_single_task(self.provider, scim_sync).get()
self.assertEqual(mock.call_count, 3) self.assertEqual(mock.call_count, 3)
for request in mock.request_history: for request in mock.request_history:
self.assertIn(request.method, SAFE_METHODS) self.assertIn(request.method, SAFE_METHODS)
task = list( task = SystemTask.objects.filter(uid=slugify(self.provider.name)).first()
Task.objects.filter(
actor_name=scim_sync_objects.actor_name,
_uid=slugify(self.provider.name),
).order_by("-mtime")
)[1]
self.assertIsNotNone(task) self.assertIsNotNone(task)
drop_msg = task._messages[3] drop_msg = task.messages[3]
self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run")
self.assertIsNotNone(drop_msg["attributes"]["url"]) self.assertIsNotNone(drop_msg["attributes"]["url"])
self.assertIsNotNone(drop_msg["attributes"]["body"]) self.assertIsNotNone(drop_msg["attributes"]["body"])

View File

@ -1,6 +1,7 @@
"""authentik core celery""" """authentik core celery"""
import os import os
from collections.abc import Callable
from contextvars import ContextVar from contextvars import ContextVar
from logging.config import dictConfig from logging.config import dictConfig
from pathlib import Path from pathlib import Path
@ -15,9 +16,12 @@ from celery.signals import (
task_internal_error, task_internal_error,
task_postrun, task_postrun,
task_prerun, task_prerun,
worker_ready,
) )
from celery.worker.control import inspect_command from celery.worker.control import inspect_command
from django.conf import settings from django.conf import settings
from django.db import ProgrammingError
from django_tenants.utils import get_public_schema_name
from structlog.contextvars import STRUCTLOG_KEY_PREFIX from structlog.contextvars import STRUCTLOG_KEY_PREFIX
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
@ -65,10 +69,7 @@ def task_postrun_hook(task_id: str, task, *args, retval=None, state=None, **kwar
"""Log task_id on worker""" """Log task_id on worker"""
CTX_TASK_ID.set(...) CTX_TASK_ID.set(...)
LOGGER.info( LOGGER.info(
"Task finished", "Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state
task_id=task_id.replace("-", ""),
task_name=task.__name__,
state=state,
) )
@ -82,12 +83,51 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
CTX_TASK_ID.set(...) CTX_TASK_ID.set(...)
if not should_ignore_exception(exception): if not should_ignore_exception(exception):
Event.new( Event.new(
EventAction.SYSTEM_EXCEPTION, EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
message=exception_to_string(exception),
task_id=task_id,
).save() ).save()
def _get_startup_tasks_default_tenant() -> list[Callable]:
"""Get all tasks to be run on startup for the default tenant"""
from authentik.outposts.tasks import outpost_connection_discovery
return [
outpost_connection_discovery,
]
def _get_startup_tasks_all_tenants() -> list[Callable]:
"""Get all tasks to be run on startup for all tenants"""
return []
@worker_ready.connect
def worker_ready_hook(*args, **kwargs):
"""Run certain tasks on worker start"""
from authentik.tenants.models import Tenant
LOGGER.info("Dispatching startup tasks...")
def _run_task(task: Callable):
try:
task.delay()
except ProgrammingError as exc:
LOGGER.warning("Startup task failed", task=task, exc=exc)
for task in _get_startup_tasks_default_tenant():
with Tenant.objects.get(schema_name=get_public_schema_name()):
_run_task(task)
for task in _get_startup_tasks_all_tenants():
for tenant in Tenant.objects.filter(ready=True):
with tenant:
_run_task(task)
from authentik.blueprints.v1.tasks import start_blueprint_watcher
start_blueprint_watcher()
class LivenessProbe(bootsteps.StartStopStep): class LivenessProbe(bootsteps.StartStopStep):
"""Add a timed task to touch a temporary file for healthchecking reasons""" """Add a timed task to touch a temporary file for healthchecking reasons"""

View File

@ -4,9 +4,9 @@ import importlib
from collections import OrderedDict from collections import OrderedDict
from hashlib import sha512 from hashlib import sha512
from pathlib import Path from pathlib import Path
from tempfile import gettempdir
import orjson import orjson
from celery.schedules import crontab
from sentry_sdk import set_tag from sentry_sdk import set_tag
from xmlsec import enable_debug_trace from xmlsec import enable_debug_trace
@ -65,18 +65,14 @@ SHARED_APPS = [
"pgactivity", "pgactivity",
"pglock", "pglock",
"channels", "channels",
"django_dramatiq_postgres",
"authentik.tasks",
] ]
TENANT_APPS = [ TENANT_APPS = [
"django.contrib.auth", "django.contrib.auth",
"django.contrib.contenttypes", "django.contrib.contenttypes",
"django.contrib.sessions", "django.contrib.sessions",
"pgtrigger",
"authentik.admin", "authentik.admin",
"authentik.api", "authentik.api",
"authentik.crypto", "authentik.crypto",
"authentik.events",
"authentik.flows", "authentik.flows",
"authentik.outposts", "authentik.outposts",
"authentik.policies.dummy", "authentik.policies.dummy",
@ -124,7 +120,6 @@ TENANT_APPS = [
"authentik.stages.user_login", "authentik.stages.user_login",
"authentik.stages.user_logout", "authentik.stages.user_logout",
"authentik.stages.user_write", "authentik.stages.user_write",
"authentik.tasks.schedules",
"authentik.brands", "authentik.brands",
"authentik.blueprints", "authentik.blueprints",
"guardian", "guardian",
@ -171,7 +166,6 @@ SPECTACULAR_SETTINGS = {
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
"UserTypeEnum": "authentik.core.models.UserTypes", "UserTypeEnum": "authentik.core.models.UserTypes",
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction", "OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
"TaskAggregatedStatusEnum": "authentik.tasks.models.TaskStatus",
}, },
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False, "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
"ENUM_GENERATE_CHOICE_DESCRIPTION": False, "ENUM_GENERATE_CHOICE_DESCRIPTION": False,
@ -347,85 +341,37 @@ USE_TZ = True
LOCALE_PATHS = ["./locale"] LOCALE_PATHS = ["./locale"]
CELERY = {
# Tests "task_soft_time_limit": 600,
"worker_max_tasks_per_child": 50,
TEST = False "worker_concurrency": CONFIG.get_int("worker.concurrency"),
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" "beat_schedule": {
"clean_expired_models": {
"task": "authentik.core.tasks.clean_expired_models",
# Dramatiq "schedule": crontab(minute="2-59/5"),
"options": {"queue": "authentik_scheduled"},
DRAMATIQ = { },
"broker_class": "authentik.tasks.broker.Broker", "user_cleanup": {
"channel_prefix": "authentik", "task": "authentik.core.tasks.clean_temporary_users",
"task_model": "authentik.tasks.models.Task", "schedule": crontab(minute="9-59/5"),
"task_purge_interval": timedelta_from_string( "options": {"queue": "authentik_scheduled"},
CONFIG.get("worker.task_purge_interval") },
).total_seconds(),
"task_expiration": timedelta_from_string(CONFIG.get("worker.task_expiration")).total_seconds(),
"autodiscovery": {
"enabled": True,
"setup_module": "authentik.tasks.setup",
"apps_prefix": "authentik",
}, },
"worker": { "beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler",
"processes": CONFIG.get_int("worker.processes", 2), "task_create_missing_queues": True,
"threads": CONFIG.get_int("worker.threads", 1), "task_default_queue": "authentik",
"consumer_listen_timeout": timedelta_from_string( "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")),
CONFIG.get("worker.consumer_listen_timeout") "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")),
).total_seconds(), "broker_transport_options": CONFIG.get_dict_from_b64_json(
"watch_folder": BASE_DIR / "authentik", "broker.transport_options", {"retry_policy": {"timeout": 5.0}}
},
"scheduler_class": "authentik.tasks.schedules.scheduler.Scheduler",
"schedule_model": "authentik.tasks.schedules.models.Schedule",
"scheduler_interval": timedelta_from_string(
CONFIG.get("worker.scheduler_interval")
).total_seconds(),
"middlewares": (
("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}),
# TODO: fixme
# ("dramatiq.middleware.prometheus.Prometheus", {}),
("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}),
("dramatiq.middleware.age_limit.AgeLimit", {}),
(
"dramatiq.middleware.time_limit.TimeLimit",
{
"time_limit": timedelta_from_string(
CONFIG.get("worker.task_default_time_limit")
).total_seconds()
* 1000
},
),
("dramatiq.middleware.shutdown.ShutdownNotifications", {}),
("dramatiq.middleware.callbacks.Callbacks", {}),
("dramatiq.middleware.pipelines.Pipelines", {}),
(
"dramatiq.middleware.retries.Retries",
{"max_retries": CONFIG.get_int("worker.task_max_retries") if not TEST else 0},
),
("dramatiq.results.middleware.Results", {"store_results": True}),
("django_dramatiq_postgres.middleware.CurrentTask", {}),
("authentik.tasks.middleware.TenantMiddleware", {}),
("authentik.tasks.middleware.RelObjMiddleware", {}),
("authentik.tasks.middleware.MessagesMiddleware", {}),
("authentik.tasks.middleware.LoggingMiddleware", {}),
("authentik.tasks.middleware.DescriptionMiddleware", {}),
("authentik.tasks.middleware.WorkerStatusMiddleware", {}),
(
"authentik.tasks.middleware.MetricsMiddleware",
{
"multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"),
"prefix": "authentik",
},
),
), ),
"test": TEST, "result_backend_transport_options": CONFIG.get_dict_from_b64_json(
"result_backend.transport_options", {"retry_policy": {"timeout": 5.0}}
),
"redis_retry_on_timeout": True,
} }
# Sentry integration # Sentry integration
env = get_env() env = get_env()
_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False)
if _ERROR_REPORTING: if _ERROR_REPORTING:
@ -486,6 +432,9 @@ else:
MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"]
MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"]
TEST = False
TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner"
structlog_configure() structlog_configure()
LOGGING = get_logger_config() LOGGING = get_logger_config()
@ -496,6 +445,7 @@ _DISALLOWED_ITEMS = [
"INSTALLED_APPS", "INSTALLED_APPS",
"MIDDLEWARE", "MIDDLEWARE",
"AUTHENTICATION_BACKENDS", "AUTHENTICATION_BACKENDS",
"CELERY",
"SPECTACULAR_SETTINGS", "SPECTACULAR_SETTINGS",
"REST_FRAMEWORK", "REST_FRAMEWORK",
] ]
@ -522,6 +472,7 @@ def _update_settings(app_path: str):
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module): for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:
globals()[_attr] = getattr(settings_module, _attr) globals()[_attr] = getattr(settings_module, _attr)
@ -530,6 +481,7 @@ def _update_settings(app_path: str):
if DEBUG: if DEBUG:
CELERY["task_always_eager"] = True
REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append(
"rest_framework.renderers.BrowsableAPIRenderer" "rest_framework.renderers.BrowsableAPIRenderer"
) )
@ -549,6 +501,10 @@ try:
except ImportError: except ImportError:
pass pass
# Import events after other apps since it relies on tasks and other things from all apps
# being imported for @prefill_task
TENANT_APPS.append("authentik.events")
# Load subapps's settings # Load subapps's settings
for _app in set(SHARED_APPS + TENANT_APPS): for _app in set(SHARED_APPS + TENANT_APPS):

Some files were not shown because too many files have changed in this diff Show More