Compare commits

7 commits: celery-2-d...dependabot
| Author | SHA1 | Date |
|---|---|---|
| | 3d8be6d699 | |
| | b10c795a26 | |
| | 8088e08fd9 | |
| | eab6e288d7 | |
| | 91c2863358 | |
| | 1638e95bc7 | |
| | 8f75131541 | |
.gitignore (5 changes, vendored)
@@ -100,6 +100,9 @@ ipython_config.py
 # pyenv
 .python-version
 
+# celery beat schedule file
+celerybeat-schedule
+
 # SageMath parsed files
 *.sage.py
 
@@ -163,6 +166,8 @@ dmypy.json
 
 # pyenv
 
+# celery beat schedule file
+
 # SageMath parsed files
 
 # Environments
@@ -122,7 +122,6 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
 
 RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
     --mount=type=bind,target=uv.lock,src=uv.lock \
-    --mount=type=bind,target=packages,src=packages \
     --mount=type=cache,target=/root/.cache/uv \
     uv sync --frozen --no-install-project --no-dev
 
@@ -168,7 +167,6 @@ COPY ./blueprints /blueprints
 COPY ./lifecycle/ /lifecycle
 COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
 COPY --from=go-builder /go/authentik /bin/authentik
-COPY ./packages/ /ak-root/packages
 COPY --from=python-deps /ak-root/.venv /ak-root/.venv
 COPY --from=node-builder /work/web/dist/ /web/dist/
 COPY --from=node-builder /work/web/authentik/ /web/authentik/
Makefile (2 changes)
@@ -6,7 +6,7 @@ PWD = $(shell pwd)
 UID = $(shell id -u)
 GID = $(shell id -g)
 NPM_VERSION = $(shell python -m scripts.generate_semver)
-PY_SOURCES = authentik packages tests scripts lifecycle .github
+PY_SOURCES = authentik tests scripts lifecycle .github
 DOCKER_IMAGE ?= "authentik:test"
 
 GEN_API_TS = gen-ts-api
@@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer):
             return __version__
         version_in_cache = cache.get(VERSION_CACHE_KEY)
         if not version_in_cache:  # pragma: no cover
-            update_latest_version.send()
+            update_latest_version.delay()
             return __version__
         return version_in_cache
 
authentik/admin/api/workers.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+"""authentik administration overview"""
+
+from socket import gethostname
+
+from django.conf import settings
+from drf_spectacular.utils import extend_schema, inline_serializer
+from packaging.version import parse
+from rest_framework.fields import BooleanField, CharField
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+from authentik import get_full_version
+from authentik.rbac.permissions import HasPermission
+from authentik.root.celery import CELERY_APP
+
+
+class WorkerView(APIView):
+    """Get currently connected worker count."""
+
+    permission_classes = [HasPermission("authentik_rbac.view_system_info")]
+
+    @extend_schema(
+        responses=inline_serializer(
+            "Worker",
+            fields={
+                "worker_id": CharField(),
+                "version": CharField(),
+                "version_matching": BooleanField(),
+            },
+            many=True,
+        )
+    )
+    def get(self, request: Request) -> Response:
+        """Get currently connected worker count."""
+        raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
+        our_version = parse(get_full_version())
+        response = []
+        for worker in raw:
+            key = list(worker.keys())[0]
+            version = worker[key].get("version")
+            version_matching = False
+            if version:
+                version_matching = parse(version) == our_version
+            response.append(
+                {"worker_id": key, "version": version, "version_matching": version_matching}
+            )
+        # In debug we run with `task_always_eager`, so tasks are ran on the main process
+        if settings.DEBUG:  # pragma: no cover
+            response.append(
+                {
+                    "worker_id": f"authentik-debug@{gethostname()}",
+                    "version": get_full_version(),
+                    "version_matching": True,
+                }
+            )
+        return Response(response)
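For context, Celery's `control.ping()` broadcast returns one single-key dict per responding worker, keyed by the worker's node name; the `version` field is presumably attached by authentik's workers rather than by Celery itself. A rough sketch of the transformation the view performs, using a made-up ping payload:

```python
# Hypothetical ping payload; the real reply comes from CELERY_APP.control.ping().
raw = [
    {"celery@worker-1": {"ok": "pong", "version": "2025.2.1"}},
    {"celery@worker-2": {"ok": "pong", "version": "2025.2.0"}},
]

our_version = "2025.2.1"
response = []
for worker in raw:
    worker_id = list(worker.keys())[0]          # node name is the only key
    version = worker[worker_id].get("version")  # None for workers that report nothing
    response.append(
        {
            "worker_id": worker_id,
            "version": version,
            # the view compares parsed versions via packaging.version.parse()
            "version_matching": version == our_version,
        }
    )
# response -> [{'worker_id': 'celery@worker-1', 'version': '2025.2.1', 'version_matching': True}, ...]
```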
@@ -3,9 +3,6 @@
 from prometheus_client import Info
 
 from authentik.blueprints.apps import ManagedAppConfig
-from authentik.lib.config import CONFIG
-from authentik.lib.utils.time import fqdn_rand
-from authentik.tasks.schedules.lib import ScheduleSpec
 
 PROM_INFO = Info("authentik_version", "Currently running authentik version")
 
@@ -33,15 +30,3 @@ class AuthentikAdminConfig(ManagedAppConfig):
             notification_version = notification.event.context["new_version"]
             if LOCAL_VERSION >= parse(notification_version):
                 notification.delete()
-
-    @property
-    def global_schedule_specs(self) -> list[ScheduleSpec]:
-        from authentik.admin.tasks import update_latest_version
-
-        return [
-            ScheduleSpec(
-                actor=update_latest_version,
-                crontab=f"{fqdn_rand('admin_latest_version')} * * * *",
-                paused=CONFIG.get_bool("disable_update_check"),
-            ),
-        ]
authentik/admin/settings.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+"""authentik admin settings"""
+
+from celery.schedules import crontab
+from django_tenants.utils import get_public_schema_name
+
+from authentik.lib.utils.time import fqdn_rand
+
+CELERY_BEAT_SCHEDULE = {
+    "admin_latest_version": {
+        "task": "authentik.admin.tasks.update_latest_version",
+        "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
+        "tenant_schemas": [get_public_schema_name()],
+        "options": {"queue": "authentik_scheduled"},
+    }
+}
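The `fqdn_rand("admin_latest_version")` call picks a deterministic minute per host, so separate deployments don't all query the version endpoint in the same minute while each host keeps a stable schedule. A sketch of the idea, assuming the helper hashes the host's FQDN plus a seed (the real implementation lives in `authentik.lib.utils.time` and may differ):

```python
# Illustrative sketch only, not authentik's implementation:
# derive a stable per-host value in [0, 60) from the FQDN and a seed.
from hashlib import sha256
from socket import getfqdn


def fqdn_rand_sketch(seed: str, stop: int = 60) -> int:
    digest = sha256(f"{getfqdn()}/{seed}".encode()).hexdigest()
    return int(digest, 16) % stop


# e.g. 37 -> crontab(minute=37, hour="*"): run hourly at minute 37 on this host
print(fqdn_rand_sketch("admin_latest_version"))
```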
authentik/admin/signals.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+"""admin signals"""
+
+from django.dispatch import receiver
+from packaging.version import parse
+from prometheus_client import Gauge
+
+from authentik import get_full_version
+from authentik.root.celery import CELERY_APP
+from authentik.root.monitoring import monitoring_set
+
+GAUGE_WORKERS = Gauge(
+    "authentik_admin_workers",
+    "Currently connected workers, their versions and if they are the same version as authentik",
+    ["version", "version_matched"],
+)
+
+
+_version = parse(get_full_version())
+
+
+@receiver(monitoring_set)
+def monitoring_set_workers(sender, **kwargs):
+    """Set worker gauge"""
+    raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
+    worker_version_count = {}
+    for worker in raw:
+        key = list(worker.keys())[0]
+        version = worker[key].get("version")
+        version_matching = False
+        if version:
+            version_matching = parse(version) == _version
+        worker_version_count.setdefault(version, {"count": 0, "matching": version_matching})
+        worker_version_count[version]["count"] += 1
+    for version, stats in worker_version_count.items():
+        GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"])
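Because `GAUGE_WORKERS` is labelled by version and match status, each distinct worker version becomes one time series. A small, self-contained illustration of the resulting exposition format, with made-up worker counts:

```python
# Stand-alone sketch using prometheus_client directly (not wired to Celery):
# two up-to-date workers and one stale worker become two labelled samples.
from prometheus_client import CollectorRegistry, Gauge, generate_latest

registry = CollectorRegistry()
gauge = Gauge(
    "authentik_admin_workers",
    "Currently connected workers, their versions and if they are the same version as authentik",
    ["version", "version_matched"],
    registry=registry,
)
gauge.labels("2025.2.1", True).set(2)
gauge.labels("2025.2.0", False).set(1)

print(generate_latest(registry).decode())
# ...
# authentik_admin_workers{version="2025.2.1",version_matched="True"} 2.0
# authentik_admin_workers{version="2025.2.0",version_matched="False"} 1.0
```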
@@ -2,8 +2,6 @@
 
 from django.core.cache import cache
 from django.utils.translation import gettext_lazy as _
-from django_dramatiq_postgres.middleware import CurrentTask
-from dramatiq import actor
 from packaging.version import parse
 from requests import RequestException
 from structlog.stdlib import get_logger
@@ -11,9 +9,10 @@ from structlog.stdlib import get_logger
 from authentik import __version__, get_build_hash
 from authentik.admin.apps import PROM_INFO
 from authentik.events.models import Event, EventAction
+from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 from authentik.lib.config import CONFIG
 from authentik.lib.utils.http import get_http_session
-from authentik.tasks.models import Task
+from authentik.root.celery import CELERY_APP
 
 LOGGER = get_logger()
 VERSION_NULL = "0.0.0"
@@ -33,12 +32,13 @@ def _set_prom_info():
     )
 
 
-@actor(description=_("Update latest version info."))
-def update_latest_version():
-    self: Task = CurrentTask.get_task()
+@CELERY_APP.task(bind=True, base=SystemTask)
+@prefill_task
+def update_latest_version(self: SystemTask):
+    """Update latest version info"""
     if CONFIG.get_bool("disable_update_check"):
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
-        self.info("Version check disabled.")
+        self.set_status(TaskStatus.WARNING, "Version check disabled.")
         return
     try:
         response = get_http_session().get(
@@ -48,7 +48,7 @@ def update_latest_version():
         data = response.json()
         upstream_version = data.get("stable", {}).get("version")
         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
-        self.info("Successfully updated latest Version")
+        self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
         _set_prom_info()
         # Check if upstream version is newer than what we're running,
         # and if no event exists yet, create one.
@@ -71,7 +71,7 @@ def update_latest_version():
             ).save()
     except (RequestException, IndexError) as exc:
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
-        raise exc
+        self.set_error(exc)
 
 
 _set_prom_info()
@@ -29,6 +29,13 @@ class TestAdminAPI(TestCase):
         body = loads(response.content)
         self.assertEqual(body["version_current"], __version__)
 
+    def test_workers(self):
+        """Test Workers API"""
+        response = self.client.get(reverse("authentik_api:admin_workers"))
+        self.assertEqual(response.status_code, 200)
+        body = loads(response.content)
+        self.assertEqual(len(body), 0)
+
     def test_apps(self):
         """Test apps API"""
         response = self.client.get(reverse("authentik_api:apps-list"))
@@ -30,7 +30,7 @@ class TestAdminTasks(TestCase):
         """Test Update checker with valid response"""
         with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
             mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
-            update_latest_version.send()
+            update_latest_version.delay().get()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
             self.assertTrue(
                 Event.objects.filter(
@@ -40,7 +40,7 @@ class TestAdminTasks(TestCase):
                 ).exists()
             )
             # test that a consecutive check doesn't create a duplicate event
-            update_latest_version.send()
+            update_latest_version.delay().get()
             self.assertEqual(
                 len(
                     Event.objects.filter(
@@ -56,7 +56,7 @@ class TestAdminTasks(TestCase):
         """Test Update checker with invalid response"""
         with Mocker() as mocker:
             mocker.get("https://version.goauthentik.io/version.json", status_code=400)
-            update_latest_version.send()
+            update_latest_version.delay().get()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
             self.assertFalse(
                 Event.objects.filter(
@@ -67,15 +67,14 @@ class TestAdminTasks(TestCase):
     def test_version_disabled(self):
         """Test Update checker while its disabled"""
         with CONFIG.patch("disable_update_check", True):
-            update_latest_version.send()
+            update_latest_version.delay().get()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
 
     def test_clear_update_notifications(self):
         """Test clear of previous notification"""
         admin_config = apps.get_app_config("authentik_admin")
         Event.objects.create(
-            action=EventAction.UPDATE_AVAILABLE,
-            context={"new_version": "99999999.9999999.9999999"},
+            action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
         )
         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
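These tests call `.delay().get()` and consume the result synchronously; that only works because the test settings presumably enable Celery's eager mode, where `delay()` runs the task in-process. A minimal, self-contained illustration of that behaviour:

```python
# Minimal sketch of eager mode (assumed to be what the test settings rely on):
# with task_always_eager, .delay() executes the task inline and .get() returns
# immediately, with no broker or separate worker involved.
from celery import Celery

app = Celery("demo")
app.conf.task_always_eager = True
app.conf.task_eager_propagates = True  # re-raise task exceptions in the caller


@app.task
def add(x: int, y: int) -> int:
    return x + y


assert add.delay(2, 3).get() == 5
```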
@@ -6,11 +6,13 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
 from authentik.admin.api.system import SystemView
 from authentik.admin.api.version import VersionView
 from authentik.admin.api.version_history import VersionHistoryViewSet
+from authentik.admin.api.workers import WorkerView
 
 api_urlpatterns = [
     ("admin/apps", AppsViewSet, "apps"),
     ("admin/models", ModelViewSet, "models"),
     path("admin/version/", VersionView.as_view(), name="admin_version"),
     ("admin/version/history", VersionHistoryViewSet, "version_history"),
+    path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
     path("admin/system/", SystemView.as_view(), name="admin_system"),
 ]
@@ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
         """Ensure the path (if set) specified is retrievable"""
         if path == "" or path.startswith(OCI_PREFIX):
             return path
-        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
+        files: list[dict] = blueprints_find_dict.delay().get()
         if path not in [file["path"] for file in files]:
             raise ValidationError(_("Blueprint file does not exist"))
         return path
@@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
     @action(detail=False, pagination_class=None, filter_backends=[])
     def available(self, request: Request) -> Response:
         """Get blueprints"""
-        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
+        files: list[dict] = blueprints_find_dict.delay().get()
         return Response(files)
 
     @permission_required("authentik_blueprints.view_blueprintinstance")
@@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
     def apply(self, request: Request, *args, **kwargs) -> Response:
         """Apply a blueprint"""
         blueprint = self.get_object()
-        apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint)
+        apply_blueprint.delay(str(blueprint.pk)).get()
         return self.retrieve(request, *args, **kwargs)
@@ -6,12 +6,9 @@ from inspect import ismethod
 
 from django.apps import AppConfig
 from django.db import DatabaseError, InternalError, ProgrammingError
-from dramatiq.broker import get_broker
 from structlog.stdlib import BoundLogger, get_logger
 
-from authentik.lib.utils.time import fqdn_rand
 from authentik.root.signals import startup
-from authentik.tasks.schedules.lib import ScheduleSpec
 
 
 class ManagedAppConfig(AppConfig):
@@ -37,7 +34,7 @@ class ManagedAppConfig(AppConfig):
 
     def import_related(self):
         """Automatically import related modules which rely on just being imported
-        to register themselves (mainly django signals and tasks)"""
+        to register themselves (mainly django signals and celery tasks)"""
 
         def import_relative(rel_module: str):
             try:
@@ -83,16 +80,6 @@ class ManagedAppConfig(AppConfig):
         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY
         return func
 
-    @property
-    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
-        """Get a list of schedule specs that must exist in each tenant"""
-        return []
-
-    @property
-    def global_schedule_specs(self) -> list[ScheduleSpec]:
-        """Get a list of schedule specs that must exist in the default tenant"""
-        return []
-
     def _reconcile_tenant(self) -> None:
         """reconcile ourselves for tenanted methods"""
         from authentik.tenants.models import Tenant
@@ -113,12 +100,8 @@ class ManagedAppConfig(AppConfig):
         """
         from django_tenants.utils import get_public_schema_name, schema_context
 
-        try:
-            with schema_context(get_public_schema_name()):
-                self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
-        except (DatabaseError, ProgrammingError, InternalError) as exc:
-            self.logger.debug("Failed to access database to run reconcile", exc=exc)
-            return
+        with schema_context(get_public_schema_name()):
+            self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
 
 
 class AuthentikBlueprintsConfig(ManagedAppConfig):
@@ -129,29 +112,19 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
     verbose_name = "authentik Blueprints"
     default = True
 
+    @ManagedAppConfig.reconcile_global
+    def load_blueprints_v1_tasks(self):
+        """Load v1 tasks"""
+        self.import_module("authentik.blueprints.v1.tasks")
+
+    @ManagedAppConfig.reconcile_tenant
+    def blueprints_discovery(self):
+        """Run blueprint discovery"""
+        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
+
+        blueprints_discovery.delay()
+        clear_failed_blueprints.delay()
+
     def import_models(self):
         super().import_models()
         self.import_module("authentik.blueprints.v1.meta.apply_blueprint")
-
-    @ManagedAppConfig.reconcile_global
-    def tasks_middlewares(self):
-        from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware
-
-        get_broker().add_middleware(BlueprintWatcherMiddleware())
-
-    @property
-    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
-        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
-
-        return [
-            ScheduleSpec(
-                actor=blueprints_discovery,
-                crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *",
-                send_on_startup=True,
-            ),
-            ScheduleSpec(
-                actor=clear_failed_blueprints,
-                crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *",
-                send_on_startup=True,
-            ),
-        ]
@@ -3,7 +3,6 @@
 from pathlib import Path
 from uuid import uuid4
 
-from django.contrib.contenttypes.fields import GenericRelation
 from django.contrib.postgres.fields import ArrayField
 from django.db import models
 from django.utils.translation import gettext_lazy as _
@@ -72,13 +71,6 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
     enabled = models.BooleanField(default=True)
     managed_models = ArrayField(models.TextField(), default=list)
 
-    # Manual link to tasks instead of using TasksModel because of loop imports
-    tasks = GenericRelation(
-        "authentik_tasks.Task",
-        content_type_field="rel_obj_content_type",
-        object_id_field="rel_obj_id",
-    )
-
     class Meta:
         verbose_name = _("Blueprint Instance")
         verbose_name_plural = _("Blueprint Instances")
authentik/blueprints/settings.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+"""blueprint Settings"""
+
+from celery.schedules import crontab
+
+from authentik.lib.utils.time import fqdn_rand
+
+CELERY_BEAT_SCHEDULE = {
+    "blueprints_v1_discover": {
+        "task": "authentik.blueprints.v1.tasks.blueprints_discovery",
+        "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
+        "options": {"queue": "authentik_scheduled"},
+    },
+    "blueprints_v1_cleanup": {
+        "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
+        "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
+        "options": {"queue": "authentik_scheduled"},
+    },
+}
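Each app ships its schedule in its own `settings.py`; for these entries to take effect, the project's root settings must merge the per-app `CELERY_BEAT_SCHEDULE` dicts into the single schedule the beat scheduler consumes. A sketch of that idea (an assumption about the loader, not authentik's actual code):

```python
# Sketch only: collect CELERY_BEAT_SCHEDULE from each app's settings module.
from importlib import import_module

CELERY_BEAT_SCHEDULE: dict = {}
for app in ("authentik.admin", "authentik.blueprints"):
    try:
        app_settings = import_module(f"{app}.settings")
    except ImportError:
        continue
    CELERY_BEAT_SCHEDULE.update(getattr(app_settings, "CELERY_BEAT_SCHEDULE", {}))
```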
@@ -1,2 +0,0 @@
-# Import all v1 tasks for auto task discovery
-from authentik.blueprints.v1.tasks import *  # noqa: F403
@@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
             file.seek(0)
             file_hash = sha512(file.read().encode()).hexdigest()
             file.flush()
-            blueprints_discovery.send()
+            blueprints_discovery()
             instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
             self.assertEqual(instance.last_applied_hash, file_hash)
             self.assertEqual(
@@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
                 )
             )
             file.flush()
-            blueprints_discovery.send()
+            blueprints_discovery()
             blueprint = BlueprintInstance.objects.filter(name="foo").first()
             self.assertEqual(
                 blueprint.last_applied_hash,
@@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
                 )
             )
             file.flush()
-            blueprints_discovery.send()
+            blueprints_discovery()
             blueprint.refresh_from_db()
             self.assertEqual(
                 blueprint.last_applied_hash,
@@ -57,6 +57,7 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
     EndpointDeviceConnection,
 )
 from authentik.events.logs import LogEvent, capture_logs
+from authentik.events.models import SystemTask
 from authentik.events.utils import cleanse_dict
 from authentik.flows.models import FlowToken, Stage
 from authentik.lib.models import SerializerModel
@@ -76,7 +77,6 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
 from authentik.rbac.models import Role
 from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
 from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
-from authentik.tasks.models import Task
 from authentik.tenants.models import Tenant
 
 # Context set when the serializer is created in a blueprint context
@@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]:
         SCIMProviderGroup,
         SCIMProviderUser,
         Tenant,
-        Task,
+        SystemTask,
         ConnectionToken,
         AuthorizationCode,
         AccessToken,
@@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
             return MetaResult()
         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
 
-        apply_blueprint(self.blueprint_instance.pk)
+        apply_blueprint(str(self.blueprint_instance.pk))
         return MetaResult()
 
 
@@ -4,17 +4,12 @@ from dataclasses import asdict, dataclass, field
 from hashlib import sha512
 from pathlib import Path
 from sys import platform
-from uuid import UUID
 
 from dacite.core import from_dict
-from django.conf import settings
 from django.db import DatabaseError, InternalError, ProgrammingError
 from django.utils.text import slugify
 from django.utils.timezone import now
 from django.utils.translation import gettext_lazy as _
-from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
-from dramatiq.actor import actor
-from dramatiq.middleware import Middleware
 from structlog.stdlib import get_logger
 from watchdog.events import (
     FileCreatedEvent,
@@ -36,13 +31,15 @@ from authentik.blueprints.v1.importer import Importer
 from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
 from authentik.blueprints.v1.oci import OCI_PREFIX
 from authentik.events.logs import capture_logs
+from authentik.events.models import TaskStatus
+from authentik.events.system_tasks import SystemTask, prefill_task
 from authentik.events.utils import sanitize_dict
 from authentik.lib.config import CONFIG
-from authentik.tasks.models import Task
-from authentik.tasks.schedules.models import Schedule
+from authentik.root.celery import CELERY_APP
 from authentik.tenants.models import Tenant
 
 LOGGER = get_logger()
+_file_watcher_started = False
 
 
 @dataclass
@@ -56,9 +53,13 @@ class BlueprintFile:
     meta: BlueprintMetadata | None = field(default=None)
 
 
-class BlueprintWatcherMiddleware(Middleware):
-    def start_blueprint_watcher(self):
-        """Start blueprint watcher"""
+def start_blueprint_watcher():
+    """Start blueprint watcher, if it's not running already."""
+    # This function might be called twice since it's called on celery startup
+
+    global _file_watcher_started  # noqa: PLW0603
+    if _file_watcher_started:
+        return
     observer = Observer()
     kwargs = {}
     if platform.startswith("linux"):
@@ -67,10 +68,7 @@ class BlueprintWatcherMiddleware(Middleware):
         BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
     )
     observer.start()
-
-    def after_worker_boot(self, broker, worker):
-        if not settings.TEST:
-            self.start_blueprint_watcher()
+    _file_watcher_started = True
 
 
 class BlueprintEventHandler(FileSystemEventHandler):
@@ -94,7 +92,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
         LOGGER.debug("new blueprint file created, starting discovery")
         for tenant in Tenant.objects.filter(ready=True):
             with tenant:
-                Schedule.dispatch_by_actor(blueprints_discovery)
+                blueprints_discovery.delay()
 
     def on_modified(self, event: FileSystemEvent):
         """Process file modification"""
@@ -105,14 +103,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
             with tenant:
                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
                     LOGGER.debug("modified blueprint file, starting apply", instance=instance)
-                    apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
+                    apply_blueprint.delay(instance.pk.hex)
 
 
-@actor(
-    description=_("Find blueprints as `blueprints_find` does, but return a safe dict."),
+@CELERY_APP.task(
     throws=(DatabaseError, ProgrammingError, InternalError),
 )
 def blueprints_find_dict():
+    """Find blueprints as `blueprints_find` does, but return a safe dict"""
     blueprints = []
     for blueprint in blueprints_find():
         blueprints.append(sanitize_dict(asdict(blueprint)))
@@ -148,19 +146,21 @@ def blueprints_find() -> list[BlueprintFile]:
     return blueprints
 
 
-@actor(
-    description=_("Find blueprints and check if they need to be created in the database."),
-    throws=(DatabaseError, ProgrammingError, InternalError),
+@CELERY_APP.task(
+    throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
 )
-def blueprints_discovery(path: str | None = None):
-    self: Task = CurrentTask.get_task()
+@prefill_task
+def blueprints_discovery(self: SystemTask, path: str | None = None):
+    """Find blueprints and check if they need to be created in the database"""
     count = 0
     for blueprint in blueprints_find():
         if path and blueprint.path != path:
             continue
         check_blueprint_v1_file(blueprint)
         count += 1
-    self.info(f"Successfully imported {count} files.")
+    self.set_status(
+        TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
+    )
 
 
 def check_blueprint_v1_file(blueprint: BlueprintFile):
@@ -187,26 +187,22 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
         )
     if instance.last_applied_hash != blueprint.hash:
         LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
-        apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
+        apply_blueprint.delay(str(instance.pk))
 
 
-@actor(description=_("Apply single blueprint."))
-def apply_blueprint(instance_pk: UUID):
-    try:
-        self: Task = CurrentTask.get_task()
-    except CurrentTaskNotFound:
-        self = Task()
-    self.set_uid(str(instance_pk))
+@CELERY_APP.task(
+    bind=True,
+    base=SystemTask,
+)
+def apply_blueprint(self: SystemTask, instance_pk: str):
+    """Apply single blueprint"""
+    self.save_on_success = False
    instance: BlueprintInstance | None = None
    try:
        instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
-        if not instance:
-            self.warning(f"Could not find blueprint {instance_pk}, skipping")
+        if not instance or not instance.enabled:
             return
         self.set_uid(slugify(instance.name))
-        if not instance.enabled:
-            self.info(f"Blueprint {instance.name} is disabled, skipping")
-            return
         blueprint_content = instance.retrieve()
         file_hash = sha512(blueprint_content.encode()).hexdigest()
         importer = Importer.from_string(blueprint_content, instance.context)
@@ -216,18 +212,19 @@ def apply_blueprint(instance_pk: UUID):
         if not valid:
             instance.status = BlueprintInstanceStatus.ERROR
             instance.save()
-            self.logs(logs)
+            self.set_status(TaskStatus.ERROR, *logs)
             return
         with capture_logs() as logs:
             applied = importer.apply()
             if not applied:
                 instance.status = BlueprintInstanceStatus.ERROR
                 instance.save()
-                self.logs(logs)
+                self.set_status(TaskStatus.ERROR, *logs)
                 return
         instance.status = BlueprintInstanceStatus.SUCCESSFUL
         instance.last_applied_hash = file_hash
         instance.last_applied = now()
+        self.set_status(TaskStatus.SUCCESSFUL)
     except (
         OSError,
         DatabaseError,
@@ -238,14 +235,15 @@ def apply_blueprint(instance_pk: UUID):
     ) as exc:
         if instance:
             instance.status = BlueprintInstanceStatus.ERROR
-        self.error(exc)
+        self.set_error(exc)
     finally:
         if instance:
             instance.save()
 
 
-@actor(description=_("Remove blueprints which couldn't be fetched."))
+@CELERY_APP.task()
 def clear_failed_blueprints():
+    """Remove blueprints which couldn't be fetched"""
     # Exclude OCI blueprints as those might be temporarily unavailable
     for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX):
         try:
@@ -9,7 +9,6 @@ class AuthentikBrandsConfig(ManagedAppConfig):
     name = "authentik.brands"
     label = "authentik_brands"
     verbose_name = "authentik Brands"
-    default = True
     mountpoints = {
         "authentik.brands.urls_root": "",
     }
@@ -1,7 +1,8 @@
 """authentik core app config"""
 
+from django.conf import settings
+
 from authentik.blueprints.apps import ManagedAppConfig
-from authentik.tasks.schedules.lib import ScheduleSpec
 
 
 class AuthentikCoreConfig(ManagedAppConfig):
@@ -13,6 +14,14 @@ class AuthentikCoreConfig(ManagedAppConfig):
     mountpoint = ""
     default = True
 
+    @ManagedAppConfig.reconcile_global
+    def debug_worker_hook(self):
+        """Dispatch startup tasks inline when debugging"""
+        if settings.DEBUG:
+            from authentik.root.celery import worker_ready_hook
+
+            worker_ready_hook()
+
     @ManagedAppConfig.reconcile_tenant
     def source_inbuilt(self):
         """Reconcile inbuilt source"""
@@ -25,18 +34,3 @@ class AuthentikCoreConfig(ManagedAppConfig):
             },
             managed=Source.MANAGED_INBUILT,
         )
-
-    @property
-    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
-        from authentik.core.tasks import clean_expired_models, clean_temporary_users
-
-        return [
-            ScheduleSpec(
-                actor=clean_expired_models,
-                crontab="2-59/5 * * * *",
-            ),
-            ScheduleSpec(
-                actor=clean_temporary_users,
-                crontab="9-59/5 * * * *",
-            ),
-        ]
authentik/core/management/commands/bootstrap_tasks.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+"""Run bootstrap tasks"""
+
+from django.core.management.base import BaseCommand
+from django_tenants.utils import get_public_schema_name
+
+from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
+from authentik.tenants.models import Tenant
+
+
+class Command(BaseCommand):
+    """Run bootstrap tasks to ensure certain objects are created"""
+
+    def handle(self, **options):
+        for task in _get_startup_tasks_default_tenant():
+            with Tenant.objects.get(schema_name=get_public_schema_name()):
+                task()
+
+        for task in _get_startup_tasks_all_tenants():
+            for tenant in Tenant.objects.filter(ready=True):
+                with tenant:
+                    task()
authentik/core/management/commands/worker.py (new file, 47 lines)
							| @ -0,0 +1,47 @@ | |||||||
|  | """Run worker""" | ||||||
|  |  | ||||||
|  | from sys import exit as sysexit | ||||||
|  | from tempfile import tempdir | ||||||
|  |  | ||||||
|  | from celery.apps.worker import Worker | ||||||
|  | from django.core.management.base import BaseCommand | ||||||
|  | from django.db import close_old_connections | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
|  | from authentik.lib.debug import start_debug_server | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Command(BaseCommand): | ||||||
|  |     """Run worker""" | ||||||
|  |  | ||||||
|  |     def add_arguments(self, parser): | ||||||
|  |         parser.add_argument( | ||||||
|  |             "-b", | ||||||
|  |             "--beat", | ||||||
|  |             action="store_false", | ||||||
|  |             help="When set, this worker will _not_ run Beat (scheduled) tasks", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def handle(self, **options): | ||||||
|  |         LOGGER.debug("Celery options", **options) | ||||||
|  |         close_old_connections() | ||||||
|  |         start_debug_server() | ||||||
|  |         worker: Worker = CELERY_APP.Worker( | ||||||
|  |             no_color=False, | ||||||
|  |             quiet=True, | ||||||
|  |             optimization="fair", | ||||||
|  |             autoscale=(CONFIG.get_int("worker.concurrency"), 1), | ||||||
|  |             task_events=True, | ||||||
|  |             beat=options.get("beat", True), | ||||||
|  |             schedule_filename=f"{tempdir}/celerybeat-schedule", | ||||||
|  |             queues=["authentik", "authentik_scheduled", "authentik_events"], | ||||||
|  |         ) | ||||||
|  |         for task in CELERY_APP.tasks: | ||||||
|  |             LOGGER.debug("Registered task", task=task) | ||||||
|  |  | ||||||
|  |         worker.start() | ||||||
|  |         sysexit(worker.exitcode) | ||||||
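A minimal usage sketch for the new worker command, again assuming the standard Django management entry point. Note that because the flag is declared with store_false, passing -b/--beat on the command line disables the embedded beat scheduler rather than enabling it:

    from django.core.management import call_command

    # Worker plus embedded beat scheduler (the default); blocks until the worker exits.
    call_command("worker")

    # Worker only, without scheduled (beat) tasks; mirrors `./manage.py worker --beat`.
    call_command("worker", beat=False)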
| @ -3,9 +3,6 @@ | |||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from dramatiq.actor import actor |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -14,14 +11,17 @@ from authentik.core.models import ( | |||||||
|     ExpiringModel, |     ExpiringModel, | ||||||
|     User, |     User, | ||||||
| ) | ) | ||||||
| from authentik.tasks.models import Task | from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Remove expired objects.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def clean_expired_models(): | @prefill_task | ||||||
|     self: Task = CurrentTask.get_task() | def clean_expired_models(self: SystemTask): | ||||||
|  |     """Remove expired objects""" | ||||||
|  |     messages = [] | ||||||
|     for cls in ExpiringModel.__subclasses__(): |     for cls in ExpiringModel.__subclasses__(): | ||||||
|         cls: ExpiringModel |         cls: ExpiringModel | ||||||
|         objects = ( |         objects = ( | ||||||
| @ -31,13 +31,16 @@ def clean_expired_models(): | |||||||
|         for obj in objects: |         for obj in objects: | ||||||
|             obj.expire_action() |             obj.expire_action() | ||||||
|         LOGGER.debug("Expired models", model=cls, amount=amount) |         LOGGER.debug("Expired models", model=cls, amount=amount) | ||||||
|         self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") |         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||||
|  |     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Remove temporary users created by SAML Sources.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def clean_temporary_users(): | @prefill_task | ||||||
|     self: Task = CurrentTask.get_task() | def clean_temporary_users(self: SystemTask): | ||||||
|  |     """Remove temporary users created by SAML Sources""" | ||||||
|     _now = datetime.now() |     _now = datetime.now() | ||||||
|  |     messages = [] | ||||||
|     deleted_users = 0 |     deleted_users = 0 | ||||||
|     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): |     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): | ||||||
|         if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): |         if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): | ||||||
| @ -49,4 +52,5 @@ def clean_temporary_users(): | |||||||
|             LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) |             LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) | ||||||
|             user.delete() |             user.delete() | ||||||
|             deleted_users += 1 |             deleted_users += 1 | ||||||
|     self.info(f"Successfully deleted {deleted_users} users.") |     messages.append(f"Successfully deleted {deleted_users} users.") | ||||||
|  |     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  | |||||||
| @ -36,7 +36,7 @@ class TestTasks(APITestCase): | |||||||
|             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API |             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API | ||||||
|         ) |         ) | ||||||
|         key = token.key |         key = token.key | ||||||
|         clean_expired_models.send() |         clean_expired_models.delay().get() | ||||||
|         token.refresh_from_db() |         token.refresh_from_db() | ||||||
|         self.assertNotEqual(key, token.key) |         self.assertNotEqual(key, token.key) | ||||||
|  |  | ||||||
| @ -50,5 +50,5 @@ class TestTasks(APITestCase): | |||||||
|                 USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), |                 USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         clean_temporary_users.send() |         clean_temporary_users.delay().get() | ||||||
|         self.assertFalse(User.objects.filter(username=username)) |         self.assertFalse(User.objects.filter(username=username)) | ||||||
|  | |||||||
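These tests call .delay().get() and expect the result synchronously, which assumes Celery runs in eager mode under the test settings (an assumption about the test configuration, not shown in this diff):

    # Typical Celery test configuration for synchronous .delay().get() calls.
    CELERY_TASK_ALWAYS_EAGER = True       # run tasks inline instead of queuing them
    CELERY_TASK_EAGER_PROPAGATES = True   # re-raise task exceptions inside the test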
| @ -4,8 +4,6 @@ from datetime import UTC, datetime | |||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
|  |  | ||||||
| MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" | MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" | ||||||
|  |  | ||||||
| @ -69,14 +67,3 @@ class AuthentikCryptoConfig(ManagedAppConfig): | |||||||
|                 "key_data": builder.private_key, |                 "key_data": builder.private_key, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.crypto.tasks import certificate_discovery |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=certificate_discovery, |  | ||||||
|                 crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  | |||||||
							
								
								
									
authentik/crypto/settings.py (new file, 13 lines)
							| @ -0,0 +1,13 @@ | |||||||
|  | """Crypto task Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "crypto_certificate_discovery": { | ||||||
|  |         "task": "authentik.crypto.tasks.certificate_discovery", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
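Per-app settings modules like this one each export a CELERY_BEAT_SCHEDULE fragment; the root settings are assumed to merge these fragments into the global beat schedule. An illustrative sketch of that merge (the actual loader lives in authentik's root settings and is not part of this diff; the app list below only names files added here):

    from importlib import import_module

    CELERY_BEAT_SCHEDULE = {}
    for app in (
        "authentik.crypto",
        "authentik.enterprise.policies.unique_password",
        "authentik.enterprise.providers.google_workspace",
        "authentik.enterprise.providers.microsoft_entra",
    ):
        try:
            app_settings = import_module(f"{app}.settings")
        except ImportError:
            continue
        # Each fragment already routes its entries to the authentik_scheduled queue.
        CELERY_BEAT_SCHEDULE.update(getattr(app_settings, "CELERY_BEAT_SCHEDULE", {}))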
| @ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend | |||||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||||
| from cryptography.x509.base import load_pem_x509_certificate | from cryptography.x509.base import load_pem_x509_certificate | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from dramatiq.actor import actor |  | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.tasks.models import Task | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -36,9 +36,10 @@ def ensure_certificate_valid(body: str): | |||||||
|     return body |     return body | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Discover, import and update certificates from the filesystem.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def certificate_discovery(): | @prefill_task | ||||||
|     self: Task = CurrentTask.get_task() | def certificate_discovery(self: SystemTask): | ||||||
|  |     """Discover, import and update certificates from the filesystem""" | ||||||
|     certs = {} |     certs = {} | ||||||
|     private_keys = {} |     private_keys = {} | ||||||
|     discovered = 0 |     discovered = 0 | ||||||
| @ -83,4 +84,6 @@ def certificate_discovery(): | |||||||
|                 dirty = True |                 dirty = True | ||||||
|         if dirty: |         if dirty: | ||||||
|             cert.save() |             cert.save() | ||||||
|     self.info(f"Successfully imported {discovered} files.") |     self.set_status( | ||||||
|  |         TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered)) | ||||||
|  |     ) | ||||||
|  | |||||||
| @ -338,7 +338,7 @@ class TestCrypto(APITestCase): | |||||||
|             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: |             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: | ||||||
|                 _key.write(builder.private_key) |                 _key.write(builder.private_key) | ||||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): |             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||||
|                 certificate_discovery.send() |                 certificate_discovery() | ||||||
|         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( |         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||||
|             managed=MANAGED_DISCOVERED % "foo" |             managed=MANAGED_DISCOVERED % "foo" | ||||||
|         ).first() |         ).first() | ||||||
|  | |||||||
| @ -3,8 +3,6 @@ | |||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnterpriseConfig(ManagedAppConfig): | class EnterpriseConfig(ManagedAppConfig): | ||||||
| @ -28,14 +26,3 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | |||||||
|         from authentik.enterprise.license import LicenseKey |         from authentik.enterprise.license import LicenseKey | ||||||
|  |  | ||||||
|         return LicenseKey.cached_summary().status.is_valid |         return LicenseKey.cached_summary().status.is_valid | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.enterprise.tasks import enterprise_update_usage |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=enterprise_update_usage, |  | ||||||
|                 crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  | |||||||
| @ -1,8 +1,6 @@ | |||||||
| """authentik Unique Password policy app config""" | """authentik Unique Password policy app config""" | ||||||
|  |  | ||||||
| from authentik.enterprise.apps import EnterpriseConfig | from authentik.enterprise.apps import EnterpriseConfig | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | ||||||
| @ -10,21 +8,3 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | |||||||
|     label = "authentik_policies_unique_password" |     label = "authentik_policies_unique_password" | ||||||
|     verbose_name = "authentik Enterprise.Policies.Unique Password" |     verbose_name = "authentik Enterprise.Policies.Unique Password" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.enterprise.policies.unique_password.tasks import ( |  | ||||||
|             check_and_purge_password_history, |  | ||||||
|             trim_password_histories, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=trim_password_histories, |  | ||||||
|                 crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *", |  | ||||||
|             ), |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=check_and_purge_password_history, |  | ||||||
|                 crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  | |||||||
							
								
								
									
authentik/enterprise/policies/unique_password/settings.py (new file, 20 lines)
							| @ -0,0 +1,20 @@ | |||||||
|  | """Unique Password Policy settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "policies_unique_password_trim_history": { | ||||||
|  |         "task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  |     "policies_unique_password_check_purge": { | ||||||
|  |         "task": ( | ||||||
|  |             "authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history" | ||||||
|  |         ), | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
| @ -1,37 +1,35 @@ | |||||||
| from django.db.models.aggregates import Count | from django.db.models.aggregates import Count | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from dramatiq.actor import actor |  | ||||||
| from structlog import get_logger | from structlog import get_logger | ||||||
|  |  | ||||||
| from authentik.enterprise.policies.unique_password.models import ( | from authentik.enterprise.policies.unique_password.models import ( | ||||||
|     UniquePasswordPolicy, |     UniquePasswordPolicy, | ||||||
|     UserPasswordHistory, |     UserPasswordHistory, | ||||||
| ) | ) | ||||||
| from authentik.tasks.models import Task | from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
|     description=_( | @prefill_task | ||||||
|         "Check if any UniquePasswordPolicy exists, and if not, purge the password history table." | def check_and_purge_password_history(self: SystemTask): | ||||||
|     ) |     """Check if any UniquePasswordPolicy exists, and if not, purge the password history table. | ||||||
| ) |     This is run on a schedule instead of being triggered by policy binding deletion. | ||||||
| def check_and_purge_password_history(): |     """ | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|  |  | ||||||
|     if not UniquePasswordPolicy.objects.exists(): |     if not UniquePasswordPolicy.objects.exists(): | ||||||
|         UserPasswordHistory.objects.all().delete() |         UserPasswordHistory.objects.all().delete() | ||||||
|         LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") |         LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") | ||||||
|         self.info("Successfully purged UserPasswordHistory") |         self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory") | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     self.info("Not purging password histories, a unique password policy exists") |     self.set_status( | ||||||
|  |         TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Remove user password history that are too old.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def trim_password_histories(): | def trim_password_histories(self: SystemTask): | ||||||
|     """Removes rows from UserPasswordHistory older than |     """Removes rows from UserPasswordHistory older than | ||||||
|     the `n` most recent entries. |     the `n` most recent entries. | ||||||
|  |  | ||||||
| @ -39,8 +37,6 @@ def trim_password_histories(): | |||||||
|     UniquePasswordPolicy policies. |     UniquePasswordPolicy policies. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|  |  | ||||||
|     # No policy, we'll let the cleanup above do its thing |     # No policy, we'll let the cleanup above do its thing | ||||||
|     if not UniquePasswordPolicy.objects.exists(): |     if not UniquePasswordPolicy.objects.exists(): | ||||||
|         return |         return | ||||||
| @ -67,4 +63,4 @@ def trim_password_histories(): | |||||||
|  |  | ||||||
|     num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() |     num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() | ||||||
|     LOGGER.debug("Deleted stale password history records", count=num_deleted) |     LOGGER.debug("Deleted stale password history records", count=num_deleted) | ||||||
|     self.info(f"Delete {num_deleted} stale password history records") |     self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records") | ||||||
|  | |||||||
| @ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | |||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
|  |  | ||||||
|         # Run the task - should purge since no policy is in use |         # Run the task - should purge since no policy is in use | ||||||
|         check_and_purge_password_history.send() |         check_and_purge_password_history() | ||||||
|  |  | ||||||
|         # Verify the table is empty |         # Verify the table is empty | ||||||
|         self.assertFalse(UserPasswordHistory.objects.exists()) |         self.assertFalse(UserPasswordHistory.objects.exists()) | ||||||
| @ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | |||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
|  |  | ||||||
|         # Run the task - should NOT purge since a policy is in use |         # Run the task - should NOT purge since a policy is in use | ||||||
|         check_and_purge_password_history.send() |         check_and_purge_password_history() | ||||||
|  |  | ||||||
|         # Verify the entries still exist |         # Verify the entries still exist | ||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
| @ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=True, |             enabled=True, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.send() |         trim_password_histories.delay() | ||||||
|         user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) |         user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) | ||||||
|         self.assertEqual(len(user_pwd_history_qs), 1) |         self.assertEqual(len(user_pwd_history_qs), 1) | ||||||
|  |  | ||||||
| @ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=False, |             enabled=False, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.send() |         trim_password_histories.delay() | ||||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) |         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||||
|  |  | ||||||
|     def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): |     def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): | ||||||
| @ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=True, |             enabled=True, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.send() |         trim_password_histories.delay() | ||||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) |         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||||
|  | |||||||
| @ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi | |||||||
|     ] |     ] | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     sync_task = google_workspace_sync |     sync_single_task = google_workspace_sync | ||||||
|     sync_objects_task = google_workspace_sync_objects |     sync_objects_task = google_workspace_sync_objects | ||||||
|  | |||||||
| @ -7,7 +7,6 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from dramatiq.actor import Actor |  | ||||||
| from google.oauth2.service_account import Credentials | from google.oauth2.service_account import Credentials | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| @ -111,12 +110,6 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|         help_text=_("Property mappings used for group creation/updating."), |         help_text=_("Property mappings used for group creation/updating."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sync_actor(self) -> Actor: |  | ||||||
|         from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync |  | ||||||
|  |  | ||||||
|         return google_workspace_sync |  | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, |         self, | ||||||
|         model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], |         model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], | ||||||
|  | |||||||
							
								
								
									
authentik/enterprise/providers/google_workspace/settings.py (new file, 13 lines)
							| @ -0,0 +1,13 @@ | |||||||
|  | """Google workspace provider task Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "providers_google_workspace_sync": { | ||||||
|  |         "task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
| @ -2,13 +2,15 @@ | |||||||
|  |  | ||||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||||
| from authentik.enterprise.providers.google_workspace.tasks import ( | from authentik.enterprise.providers.google_workspace.tasks import ( | ||||||
|     google_workspace_sync_direct_dispatch, |     google_workspace_sync, | ||||||
|     google_workspace_sync_m2m_dispatch, |     google_workspace_sync_direct, | ||||||
|  |     google_workspace_sync_m2m, | ||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     GoogleWorkspaceProvider, |     GoogleWorkspaceProvider, | ||||||
|     task_sync_direct_dispatch=google_workspace_sync_direct_dispatch, |     task_sync_single=google_workspace_sync, | ||||||
|     task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch, |     task_sync_direct=google_workspace_sync_direct, | ||||||
|  |     task_sync_m2m=google_workspace_sync_m2m, | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,48 +1,37 @@ | |||||||
| """Google Provider tasks""" | """Google Provider tasks""" | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from dramatiq.actor import actor |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||||
|  | from authentik.events.system_tasks import SystemTask | ||||||
|  | from authentik.lib.sync.outgoing.exceptions import TransientSyncException | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(GoogleWorkspaceProvider) | sync_tasks = SyncTasks(GoogleWorkspaceProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync Google Workspace provider objects.")) | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def google_workspace_sync_objects(*args, **kwargs): | def google_workspace_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Full sync for Google Workspace provider.")) | @CELERY_APP.task( | ||||||
| def google_workspace_sync(provider_pk: int, *args, **kwargs): |     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | ||||||
|  | ) | ||||||
|  | def google_workspace_sync(self, provider_pk: int, *args, **kwargs): | ||||||
|     """Run full sync for Google Workspace provider""" |     """Run full sync for Google Workspace provider""" | ||||||
|     return sync_tasks.sync(provider_pk, google_workspace_sync_objects) |     return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a direct object (user, group) for Google Workspace provider.")) | @CELERY_APP.task() | ||||||
|  | def google_workspace_sync_all(): | ||||||
|  |     return sync_tasks.sync_all(google_workspace_sync) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def google_workspace_sync_direct(*args, **kwargs): | def google_workspace_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
|     description=_( |  | ||||||
|         "Dispatch syncs for a direct object (user, group) for Google Workspace providers." |  | ||||||
|     ) |  | ||||||
| ) |  | ||||||
| def google_workspace_sync_direct_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a related object (memberships) for Google Workspace provider.")) |  | ||||||
| def google_workspace_sync_m2m(*args, **kwargs): | def google_workspace_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( |  | ||||||
|     description=_( |  | ||||||
|         "Dispatch syncs for a related object (memberships) for Google Workspace providers." |  | ||||||
|     ) |  | ||||||
| ) |  | ||||||
| def google_workspace_sync_m2m_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs) |  | ||||||
|  | |||||||
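The sync tasks above rely on Celery's declarative retry options. A small sketch of what autoretry_for plus retry_backoff means for one of these tasks (the task name here is hypothetical, for illustration only):

    from authentik.lib.sync.outgoing.exceptions import TransientSyncException
    from authentik.root.celery import CELERY_APP

    @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True)
    def example_sync_step():
        # If this raises TransientSyncException, Celery re-queues the task
        # automatically with exponentially increasing delays between attempts.
        ...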
| @ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase): | |||||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", |             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), |             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||||
|         ): |         ): | ||||||
|             google_workspace_sync.send(self.provider.pk).get_result() |             google_workspace_sync.delay(self.provider.pk).get() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 GoogleWorkspaceProviderGroup.objects.filter( |                 GoogleWorkspaceProviderGroup.objects.filter( | ||||||
|                     group=different_group, provider=self.provider |                     group=different_group, provider=self.provider | ||||||
|  | |||||||
| @ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase): | |||||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", |             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), |             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||||
|         ): |         ): | ||||||
|             google_workspace_sync.send(self.provider.pk).get_result() |             google_workspace_sync.delay(self.provider.pk).get() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 GoogleWorkspaceProviderUser.objects.filter( |                 GoogleWorkspaceProviderUser.objects.filter( | ||||||
|                     user=different_user, provider=self.provider |                     user=different_user, provider=self.provider | ||||||
|  | |||||||
| @ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin | |||||||
|     ] |     ] | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     sync_task = microsoft_entra_sync |     sync_single_task = microsoft_entra_sync | ||||||
|     sync_objects_task = microsoft_entra_sync_objects |     sync_objects_task = microsoft_entra_sync_objects | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from dramatiq.actor import Actor |  | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -100,12 +99,6 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|         help_text=_("Property mappings used for group creation/updating."), |         help_text=_("Property mappings used for group creation/updating."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sync_actor(self) -> Actor: |  | ||||||
|         from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync |  | ||||||
|  |  | ||||||
|         return microsoft_entra_sync |  | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, |         self, | ||||||
|         model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], |         model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], | ||||||
|  | |||||||
							
								
								
									
authentik/enterprise/providers/microsoft_entra/settings.py (new file, 13 lines)
							| @ -0,0 +1,13 @@ | |||||||
|  | """Microsoft Entra provider task Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "providers_microsoft_entra_sync": { | ||||||
|  |         "task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
| @ -2,13 +2,15 @@ | |||||||
|  |  | ||||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||||
| from authentik.enterprise.providers.microsoft_entra.tasks import ( | from authentik.enterprise.providers.microsoft_entra.tasks import ( | ||||||
|     microsoft_entra_sync_direct_dispatch, |     microsoft_entra_sync, | ||||||
|     microsoft_entra_sync_m2m_dispatch, |     microsoft_entra_sync_direct, | ||||||
|  |     microsoft_entra_sync_m2m, | ||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     MicrosoftEntraProvider, |     MicrosoftEntraProvider, | ||||||
|     task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch, |     task_sync_single=microsoft_entra_sync, | ||||||
|     task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch, |     task_sync_direct=microsoft_entra_sync_direct, | ||||||
|  |     task_sync_m2m=microsoft_entra_sync_m2m, | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,46 +1,37 @@ | |||||||
| """Microsoft Entra Provider tasks""" | """Microsoft Entra Provider tasks""" | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from dramatiq.actor import actor |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||||
|  | from authentik.events.system_tasks import SystemTask | ||||||
|  | from authentik.lib.sync.outgoing.exceptions import TransientSyncException | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(MicrosoftEntraProvider) | sync_tasks = SyncTasks(MicrosoftEntraProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync Microsoft Entra provider objects.")) | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def microsoft_entra_sync_objects(*args, **kwargs): | def microsoft_entra_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Full sync for Microsoft Entra provider.")) | @CELERY_APP.task( | ||||||
| def microsoft_entra_sync(provider_pk: int, *args, **kwargs): |     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | ||||||
|  | ) | ||||||
|  | def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs): | ||||||
|     """Run full sync for Microsoft Entra provider""" |     """Run full sync for Microsoft Entra provider""" | ||||||
|     return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects) |     return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider.")) | @CELERY_APP.task() | ||||||
|  | def microsoft_entra_sync_all(): | ||||||
|  |     return sync_tasks.sync_all(microsoft_entra_sync) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def microsoft_entra_sync_direct(*args, **kwargs): | def microsoft_entra_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
|     description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.") |  | ||||||
| ) |  | ||||||
| def microsoft_entra_sync_direct_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a related object (memberships) for Microsoft Entra provider.")) |  | ||||||
| def microsoft_entra_sync_m2m(*args, **kwargs): | def microsoft_entra_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( |  | ||||||
|     description=_( |  | ||||||
|         "Dispatch syncs for a related object (memberships) for Microsoft Entra providers." |  | ||||||
|     ) |  | ||||||
| ) |  | ||||||
| def microsoft_entra_sync_m2m_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs) |  | ||||||
|  | |||||||
| @ -252,13 +252,9 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|             member_add.assert_called_once() |             member_add.assert_called_once() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 member_add.call_args[0][0].odata_id, |                 member_add.call_args[0][0].odata_id, | ||||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ |                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( | ||||||
|                     MicrosoftEntraProviderUser.objects.filter( |  | ||||||
|                         provider=self.provider, |                         provider=self.provider, | ||||||
|                     ) |                     ).first().microsoft_id}", | ||||||
|                     .first() |  | ||||||
|                     .microsoft_id |  | ||||||
|                 }", |  | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_group_create_member_remove(self): |     def test_group_create_member_remove(self): | ||||||
| @ -315,13 +311,9 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|             member_add.assert_called_once() |             member_add.assert_called_once() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 member_add.call_args[0][0].odata_id, |                 member_add.call_args[0][0].odata_id, | ||||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ |                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( | ||||||
|                     MicrosoftEntraProviderUser.objects.filter( |  | ||||||
|                         provider=self.provider, |                         provider=self.provider, | ||||||
|                     ) |                     ).first().microsoft_id}", | ||||||
|                     .first() |  | ||||||
|                     .microsoft_id |  | ||||||
|                 }", |  | ||||||
|             ) |             ) | ||||||
|             member_remove.assert_called_once() |             member_remove.assert_called_once() | ||||||
|  |  | ||||||
| @ -421,7 +413,7 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             ) as group_list, |             ) as group_list, | ||||||
|         ): |         ): | ||||||
|             microsoft_entra_sync.send(self.provider.pk).get_result() |             microsoft_entra_sync.delay(self.provider.pk).get() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 MicrosoftEntraProviderGroup.objects.filter( |                 MicrosoftEntraProviderGroup.objects.filter( | ||||||
|                     group=different_group, provider=self.provider |                     group=different_group, provider=self.provider | ||||||
|  | |||||||
| @ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase): | |||||||
|                 AsyncMock(return_value=GroupCollectionResponse(value=[])), |                 AsyncMock(return_value=GroupCollectionResponse(value=[])), | ||||||
|             ), |             ), | ||||||
|         ): |         ): | ||||||
|             microsoft_entra_sync.send(self.provider.pk).get_result() |             microsoft_entra_sync.delay(self.provider.pk).get() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 MicrosoftEntraProviderUser.objects.filter( |                 MicrosoftEntraProviderUser.objects.filter( | ||||||
|                     user=different_user, provider=self.provider |                     user=different_user, provider=self.provider | ||||||
|  | |||||||
| @ -17,7 +17,6 @@ from authentik.crypto.models import CertificateKeyPair | |||||||
| from authentik.lib.models import CreatedUpdatedModel | from authentik.lib.models import CreatedUpdatedModel | ||||||
| from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | ||||||
| from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider | from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider | ||||||
| from authentik.tasks.models import TasksModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EventTypes(models.TextChoices): | class EventTypes(models.TextChoices): | ||||||
| @ -43,7 +42,7 @@ class SSFEventStatus(models.TextChoices): | |||||||
|     SENT = "sent" |     SENT = "sent" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SSFProvider(TasksModel, BackchannelProvider): | class SSFProvider(BackchannelProvider): | ||||||
|     """Shared Signals Framework provider to allow applications to |     """Shared Signals Framework provider to allow applications to | ||||||
|     receive user events from authentik.""" |     receive user events from authentik.""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     EventTypes, |     EventTypes, | ||||||
|     SSFProvider, |     SSFProvider, | ||||||
| ) | ) | ||||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_events | from authentik.enterprise.providers.ssf.tasks import send_ssf_event | ||||||
| from authentik.events.middleware import audit_ignore | from authentik.events.middleware import audit_ignore | ||||||
| from authentik.stages.authenticator.models import Device | from authentik.stages.authenticator.models import Device | ||||||
| from authentik.stages.authenticator_duo.models import DuoDevice | from authentik.stages.authenticator_duo.models import DuoDevice | ||||||
| @ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | |||||||
|  |  | ||||||
|     As this signal is also triggered with a regular logout, we can't be sure |     As this signal is also triggered with a regular logout, we can't be sure | ||||||
|     if the session has been deleted by an admin or by the user themselves.""" |     if the session has been deleted by an admin or by the user themselves.""" | ||||||
|     send_ssf_events( |     send_ssf_event( | ||||||
|         EventTypes.CAEP_SESSION_REVOKED, |         EventTypes.CAEP_SESSION_REVOKED, | ||||||
|         { |         { | ||||||
|             "initiating_entity": "user", |             "initiating_entity": "user", | ||||||
| @ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | |||||||
| @receiver(password_changed) | @receiver(password_changed) | ||||||
| def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): | def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): | ||||||
|     """Credential change trigger (password changed)""" |     """Credential change trigger (password changed)""" | ||||||
|     send_ssf_events( |     send_ssf_event( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         { |         { | ||||||
|             "credential_type": "password", |             "credential_type": "password", | ||||||
| @ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, * | |||||||
|     } |     } | ||||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: |     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||||
|         data["fido2_aaguid"] = instance.aaguid |         data["fido2_aaguid"] = instance.aaguid | ||||||
|     send_ssf_events( |     send_ssf_event( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         data, |         data, | ||||||
|         sub_id={ |         sub_id={ | ||||||
| @ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_): | |||||||
|     } |     } | ||||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: |     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||||
|         data["fido2_aaguid"] = instance.aaguid |         data["fido2_aaguid"] = instance.aaguid | ||||||
|     send_ssf_events( |     send_ssf_event( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         data, |         data, | ||||||
|         sub_id={ |         sub_id={ | ||||||
|  | |||||||
| @ -1,11 +1,7 @@ | |||||||
| from typing import Any | from celery import group | ||||||
| from uuid import UUID |  | ||||||
|  |  | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from dramatiq.actor import actor |  | ||||||
| from requests.exceptions import RequestException | from requests.exceptions import RequestException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -17,16 +13,19 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     Stream, |     Stream, | ||||||
|     StreamEvent, |     StreamEvent, | ||||||
| ) | ) | ||||||
|  | from authentik.events.logs import LogEvent | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.tasks.models import Task | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| session = get_http_session() | session = get_http_session() | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def send_ssf_events( | def send_ssf_event( | ||||||
|     event_type: EventTypes, |     event_type: EventTypes, | ||||||
|     data: dict, |     data: dict, | ||||||
|     stream_filter: dict | None = None, |     stream_filter: dict | None = None, | ||||||
| @ -34,7 +33,7 @@ def send_ssf_events( | |||||||
|     **extra_data, |     **extra_data, | ||||||
| ): | ): | ||||||
|     """Wrapper to send an SSF event to multiple streams""" |     """Wrapper to send an SSF event to multiple streams""" | ||||||
|     events_data = {} |     payload = [] | ||||||
|     if not stream_filter: |     if not stream_filter: | ||||||
|         stream_filter = {} |         stream_filter = {} | ||||||
|     stream_filter["events_requested__contains"] = [event_type] |     stream_filter["events_requested__contains"] = [event_type] | ||||||
| @ -42,22 +41,16 @@ def send_ssf_events( | |||||||
|         extra_data.setdefault("txn", request.request_id) |         extra_data.setdefault("txn", request.request_id) | ||||||
|     for stream in Stream.objects.filter(**stream_filter): |     for stream in Stream.objects.filter(**stream_filter): | ||||||
|         event_data = stream.prepare_event_payload(event_type, data, **extra_data) |         event_data = stream.prepare_event_payload(event_type, data, **extra_data) | ||||||
|         events_data[stream.uuid] = event_data |         payload.append((str(stream.uuid), event_data)) | ||||||
|     ssf_events_dispatch.send(events_data) |     return _send_ssf_event.delay(payload) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Dispatch SSF events.")) | def _check_app_access(stream_uuid: str, event_data: dict) -> bool: | ||||||
| def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]): |  | ||||||
|     for stream_uuid, event_data in events_data.items(): |  | ||||||
|         stream = Stream.objects.filter(pk=stream_uuid).first() |  | ||||||
|         if not stream: |  | ||||||
|             continue |  | ||||||
|         send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _check_app_access(stream: Stream, event_data: dict) -> bool: |  | ||||||
|     """Check if event is related to user and if so, check |     """Check if event is related to user and if so, check | ||||||
|     if the user has access to the application""" |     if the user has access to the application""" | ||||||
|  |     stream = Stream.objects.filter(pk=stream_uuid).first() | ||||||
|  |     if not stream: | ||||||
|  |         return False | ||||||
|     # `event_data` is a dict version of a StreamEvent |     # `event_data` is a dict version of a StreamEvent | ||||||
|     sub_id = event_data.get("payload", {}).get("sub_id", {}) |     sub_id = event_data.get("payload", {}).get("sub_id", {}) | ||||||
|     email = sub_id.get("user", {}).get("email", None) |     email = sub_id.get("user", {}).get("email", None) | ||||||
| @ -72,22 +65,42 @@ def _check_app_access(stream: Stream, event_data: dict) -> bool: | |||||||
|     return engine.passing |     return engine.passing | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Send an SSF event.")) | @CELERY_APP.task() | ||||||
| def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): | def _send_ssf_event(event_data: list[tuple[str, dict]]): | ||||||
|     self: Task = CurrentTask.get_task() |     tasks = [] | ||||||
|  |     for stream, data in event_data: | ||||||
|  |         if not _check_app_access(stream, data): | ||||||
|  |             continue | ||||||
|  |         event = StreamEvent.objects.create(**data) | ||||||
|  |         tasks.extend(send_single_ssf_event(stream, str(event.uuid))) | ||||||
|  |     main_task = group(*tasks) | ||||||
|  |     main_task() | ||||||
|  |  | ||||||
|     stream = Stream.objects.filter(pk=stream_uuid).first() |  | ||||||
|  | def send_single_ssf_event(stream_id: str, evt_id: str): | ||||||
|  |     stream = Stream.objects.filter(pk=stream_id).first() | ||||||
|     if not stream: |     if not stream: | ||||||
|         return |         return | ||||||
|     if not _check_app_access(stream, event_data): |     event = StreamEvent.objects.filter(pk=evt_id).first() | ||||||
|  |     if not event: | ||||||
|         return |         return | ||||||
|     event = StreamEvent.objects.create(**event_data) |  | ||||||
|     self.set_uid(event.pk) |  | ||||||
|     if event.status == SSFEventStatus.SENT: |     if event.status == SSFEventStatus.SENT: | ||||||
|         return |         return | ||||||
|     if stream.delivery_method != DeliveryMethods.RISC_PUSH: |     if stream.delivery_method == DeliveryMethods.RISC_PUSH: | ||||||
|         return |         return [ssf_push_event.si(str(event.pk))] | ||||||
|  |     return [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
|  | def ssf_push_event(self: SystemTask, event_id: str): | ||||||
|  |     self.save_on_success = False | ||||||
|  |     event = StreamEvent.objects.filter(pk=event_id).first() | ||||||
|  |     if not event: | ||||||
|  |         return | ||||||
|  |     self.set_uid(event_id) | ||||||
|  |     if event.status == SSFEventStatus.SENT: | ||||||
|  |         self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|  |         return | ||||||
|     try: |     try: | ||||||
|         response = session.post( |         response = session.post( | ||||||
|             event.stream.endpoint_url, |             event.stream.endpoint_url, | ||||||
| @ -97,17 +110,26 @@ def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): | |||||||
|         response.raise_for_status() |         response.raise_for_status() | ||||||
|         event.status = SSFEventStatus.SENT |         event.status = SSFEventStatus.SENT | ||||||
|         event.save() |         event.save() | ||||||
|  |         self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|         return |         return | ||||||
|     except RequestException as exc: |     except RequestException as exc: | ||||||
|         LOGGER.warning("Failed to send SSF event", exc=exc) |         LOGGER.warning("Failed to send SSF event", exc=exc) | ||||||
|  |         self.set_status(TaskStatus.ERROR) | ||||||
|         attrs = {} |         attrs = {} | ||||||
|         if exc.response: |         if exc.response: | ||||||
|             attrs["response"] = { |             attrs["response"] = { | ||||||
|                 "content": exc.response.text, |                 "content": exc.response.text, | ||||||
|                 "status": exc.response.status_code, |                 "status": exc.response.status_code, | ||||||
|             } |             } | ||||||
|         self.warning(exc) |         self.set_error( | ||||||
|         self.warning("Failed to send request", **attrs) |             exc, | ||||||
|  |             LogEvent( | ||||||
|  |                 _("Failed to send request"), | ||||||
|  |                 log_level="warning", | ||||||
|  |                 logger=self.__name__, | ||||||
|  |                 attributes=attrs, | ||||||
|  |             ), | ||||||
|  |         ) | ||||||
|         # Re-up the expiry of the stream event |         # Re-up the expiry of the stream event | ||||||
|         event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) |         event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) | ||||||
|         event.status = SSFEventStatus.PENDING_FAILED |         event.status = SSFEventStatus.PENDING_FAILED | ||||||
|  | |||||||
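The new dispatch path above collects immutable Celery signatures and runs them as a group. A minimal, self-contained sketch of that pattern, using a toy task and broker rather than authentik's actual objects:

    # Illustrative only; assumes Celery is installed.
    from celery import Celery, group

    app = Celery("sketch", broker="memory://")
    app.conf.task_always_eager = True  # run inline so no worker is needed

    @app.task
    def push_event(event_id: str) -> str:
        # Stand-in for a per-event delivery task such as ssf_push_event.
        return f"pushed {event_id}"

    def dispatch(event_ids: list[str]) -> None:
        # .si() builds immutable signatures (no parent result is passed in),
        # mirroring send_single_ssf_event() returning [task.si(...)];
        # group() then runs all of them together.
        group(*(push_event.si(event_id) for event_id in event_ids)).apply_async()

    dispatch(["event-1", "event-2"])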
| @ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     SSFProvider, |     SSFProvider, | ||||||
|     Stream, |     Stream, | ||||||
| ) | ) | ||||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_events | from authentik.enterprise.providers.ssf.tasks import send_ssf_event | ||||||
| from authentik.enterprise.providers.ssf.views.base import SSFView | from authentik.enterprise.providers.ssf.views.base import SSFView | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -109,7 +109,7 @@ class StreamView(SSFView): | |||||||
|                 "User does not have permission to create stream for this provider." |                 "User does not have permission to create stream for this provider." | ||||||
|             ) |             ) | ||||||
|         instance: Stream = stream.save(provider=self.provider) |         instance: Stream = stream.save(provider=self.provider) | ||||||
|         send_ssf_events( |         send_ssf_event( | ||||||
|             EventTypes.SET_VERIFICATION, |             EventTypes.SET_VERIFICATION, | ||||||
|             { |             { | ||||||
|                 "state": None, |                 "state": None, | ||||||
|  | |||||||
| @ -1,5 +1,17 @@ | |||||||
| """Enterprise additional settings""" | """Enterprise additional settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "enterprise_update_usage": { | ||||||
|  |         "task": "authentik.enterprise.tasks.enterprise_update_usage", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| TENANT_APPS = [ | TENANT_APPS = [ | ||||||
|     "authentik.enterprise.audit", |     "authentik.enterprise.audit", | ||||||
|     "authentik.enterprise.policies.unique_password", |     "authentik.enterprise.policies.unique_password", | ||||||
|  | |||||||
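The CELERY_BEAT_SCHEDULE block added above follows the standard Celery beat configuration shape. A small sketch of the same shape on a standalone app; the app name, task path, and schedule values here are placeholders:

    from celery import Celery
    from celery.schedules import crontab

    app = Celery("sketch", broker="memory://")
    app.conf.beat_schedule = {
        "example_update_usage": {
            # Dotted path of the task to enqueue (placeholder, not authentik's).
            "task": "example.tasks.update_usage",
            # Fire at minute 30 of every second hour.
            "schedule": crontab(minute=30, hour="*/2"),
            # Route the run onto a dedicated queue.
            "options": {"queue": "authentik_scheduled"},
        }
    }
    # A beat process (`celery -A <module> beat`) reads this mapping and
    # enqueues each entry on its schedule.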
| @ -10,7 +10,6 @@ from django.utils.timezone import get_current_timezone | |||||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||||
| from authentik.enterprise.models import License | from authentik.enterprise.models import License | ||||||
| from authentik.enterprise.tasks import enterprise_update_usage | from authentik.enterprise.tasks import enterprise_update_usage | ||||||
| from authentik.tasks.schedules.models import Schedule |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=License) | @receiver(pre_save, sender=License) | ||||||
| @ -27,7 +26,7 @@ def pre_save_license(sender: type[License], instance: License, **_): | |||||||
| def post_save_license(sender: type[License], instance: License, **_): | def post_save_license(sender: type[License], instance: License, **_): | ||||||
|     """Trigger license usage calculation when license is saved""" |     """Trigger license usage calculation when license is saved""" | ||||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) |     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|     Schedule.dispatch_by_actor(enterprise_update_usage) |     enterprise_update_usage.delay() | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_delete, sender=License) | @receiver(post_delete, sender=License) | ||||||
|  | |||||||
| @ -1,11 +1,14 @@ | |||||||
| """Enterprise tasks""" | """Enterprise tasks""" | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from dramatiq.actor import actor |  | ||||||
|  |  | ||||||
| from authentik.enterprise.license import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Update enterprise license status.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def enterprise_update_usage(): | @prefill_task | ||||||
|  | def enterprise_update_usage(self: SystemTask): | ||||||
|  |     """Update enterprise license status""" | ||||||
|     LicenseKey.get_total().record_usage() |     LicenseKey.get_total().record_usage() | ||||||
|  |     self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|  | |||||||
104  authentik/events/api/tasks.py  Normal file
| @ -0,0 +1,104 @@ | |||||||
|  | """Tasks API""" | ||||||
|  |  | ||||||
|  | from importlib import import_module | ||||||
|  |  | ||||||
|  | from django.contrib import messages | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from drf_spectacular.types import OpenApiTypes | ||||||
|  | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
|  | from rest_framework.decorators import action | ||||||
|  | from rest_framework.fields import ( | ||||||
|  |     CharField, | ||||||
|  |     ChoiceField, | ||||||
|  |     DateTimeField, | ||||||
|  |     FloatField, | ||||||
|  |     SerializerMethodField, | ||||||
|  | ) | ||||||
|  | from rest_framework.request import Request | ||||||
|  | from rest_framework.response import Response | ||||||
|  | from rest_framework.viewsets import ReadOnlyModelViewSet | ||||||
|  | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.api.utils import ModelSerializer | ||||||
|  | from authentik.events.logs import LogEventSerializer | ||||||
|  | from authentik.events.models import SystemTask, TaskStatus | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SystemTaskSerializer(ModelSerializer): | ||||||
|  |     """Serialize SystemTask""" | ||||||
|  |  | ||||||
|  |     name = CharField() | ||||||
|  |     full_name = SerializerMethodField() | ||||||
|  |     uid = CharField(required=False) | ||||||
|  |     description = CharField() | ||||||
|  |     start_timestamp = DateTimeField(read_only=True) | ||||||
|  |     finish_timestamp = DateTimeField(read_only=True) | ||||||
|  |     duration = FloatField(read_only=True) | ||||||
|  |  | ||||||
|  |     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||||
|  |     messages = LogEventSerializer(many=True) | ||||||
|  |  | ||||||
|  |     def get_full_name(self, instance: SystemTask) -> str: | ||||||
|  |         """Get full name with UID""" | ||||||
|  |         if instance.uid: | ||||||
|  |             return f"{instance.name}:{instance.uid}" | ||||||
|  |         return instance.name | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         model = SystemTask | ||||||
|  |         fields = [ | ||||||
|  |             "uuid", | ||||||
|  |             "name", | ||||||
|  |             "full_name", | ||||||
|  |             "uid", | ||||||
|  |             "description", | ||||||
|  |             "start_timestamp", | ||||||
|  |             "finish_timestamp", | ||||||
|  |             "duration", | ||||||
|  |             "status", | ||||||
|  |             "messages", | ||||||
|  |             "expires", | ||||||
|  |             "expiring", | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SystemTaskViewSet(ReadOnlyModelViewSet): | ||||||
|  |     """Read-only view set that returns all background tasks""" | ||||||
|  |  | ||||||
|  |     queryset = SystemTask.objects.all() | ||||||
|  |     serializer_class = SystemTaskSerializer | ||||||
|  |     filterset_fields = ["name", "uid", "status"] | ||||||
|  |     ordering = ["name", "uid", "status"] | ||||||
|  |     search_fields = ["name", "description", "uid", "status"] | ||||||
|  |  | ||||||
|  |     @permission_required(None, ["authentik_events.run_task"]) | ||||||
|  |     @extend_schema( | ||||||
|  |         request=OpenApiTypes.NONE, | ||||||
|  |         responses={ | ||||||
|  |             204: OpenApiResponse(description="Task retried successfully"), | ||||||
|  |             404: OpenApiResponse(description="Task not found"), | ||||||
|  |             500: OpenApiResponse(description="Failed to retry task"), | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||||
|  |     def run(self, request: Request, pk=None) -> Response: | ||||||
|  |         """Run task""" | ||||||
|  |         task: SystemTask = self.get_object() | ||||||
|  |         try: | ||||||
|  |             task_module = import_module(task.task_call_module) | ||||||
|  |             task_func = getattr(task_module, task.task_call_func) | ||||||
|  |             LOGGER.info("Running task", task=task_func) | ||||||
|  |             task_func.delay(*task.task_call_args, **task.task_call_kwargs) | ||||||
|  |             messages.success( | ||||||
|  |                 self.request, | ||||||
|  |                 _("Successfully started task {name}.".format_map({"name": task.name})), | ||||||
|  |             ) | ||||||
|  |             return Response(status=204) | ||||||
|  |         except (ImportError, AttributeError) as exc:  # pragma: no cover | ||||||
|  |             LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc) | ||||||
|  |             # if we get an import error, the module path has probably changed | ||||||
|  |             task.delete() | ||||||
|  |             return Response(status=500) | ||||||
| @ -1,11 +1,12 @@ | |||||||
| """authentik events app""" | """authentik events app""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
| from prometheus_client import Gauge, Histogram | from prometheus_client import Gauge, Histogram | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.config import CONFIG, ENV_PREFIX | from authentik.lib.config import CONFIG, ENV_PREFIX | ||||||
| from authentik.lib.utils.time import fqdn_rand | from authentik.lib.utils.reflection import path_to_class | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| # TODO: Deprecated metric - remove in 2024.2 or later | # TODO: Deprecated metric - remove in 2024.2 or later | ||||||
| GAUGE_TASKS = Gauge( | GAUGE_TASKS = Gauge( | ||||||
| @ -34,17 +35,6 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Events" |     verbose_name = "authentik Events" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.events.tasks import notification_cleanup |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=notification_cleanup, |  | ||||||
|                 crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_global |     @ManagedAppConfig.reconcile_global | ||||||
|     def check_deprecations(self): |     def check_deprecations(self): | ||||||
|         """Check for config deprecations""" |         """Check for config deprecations""" | ||||||
| @ -66,3 +56,41 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|                 replacement_env=replace_env, |                 replacement_env=replace_env, | ||||||
|                 message=msg, |                 message=msg, | ||||||
|             ).save() |             ).save() | ||||||
|  |  | ||||||
|  |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def prefill_tasks(self): | ||||||
|  |         """Prefill tasks""" | ||||||
|  |         from authentik.events.models import SystemTask | ||||||
|  |         from authentik.events.system_tasks import _prefill_tasks | ||||||
|  |  | ||||||
|  |         for task in _prefill_tasks: | ||||||
|  |             if SystemTask.objects.filter(name=task.name).exists(): | ||||||
|  |                 continue | ||||||
|  |             task.save() | ||||||
|  |             self.logger.debug("prefilled task", task_name=task.name) | ||||||
|  |  | ||||||
|  |     @ManagedAppConfig.reconcile_tenant | ||||||
|  |     def run_scheduled_tasks(self): | ||||||
|  |         """Run scheduled tasks which are behind schedule (only applies | ||||||
|  |         to tasks for which we keep metrics)""" | ||||||
|  |         from authentik.events.models import TaskStatus | ||||||
|  |         from authentik.events.system_tasks import SystemTask as CelerySystemTask | ||||||
|  |  | ||||||
|  |         for task in CELERY_APP.conf["beat_schedule"].values(): | ||||||
|  |             schedule = task["schedule"] | ||||||
|  |             if not isinstance(schedule, crontab): | ||||||
|  |                 continue | ||||||
|  |             task_class: CelerySystemTask = path_to_class(task["task"]) | ||||||
|  |             if not isinstance(task_class, CelerySystemTask): | ||||||
|  |                 continue | ||||||
|  |             db_task = task_class.db() | ||||||
|  |             if not db_task: | ||||||
|  |                 continue | ||||||
|  |             due, _ = schedule.is_due(db_task.finish_timestamp) | ||||||
|  |             if due or db_task.status == TaskStatus.UNKNOWN: | ||||||
|  |                 self.logger.debug("Running past-due scheduled task", task=task["task"]) | ||||||
|  |                 task_class.apply_async( | ||||||
|  |                     args=task.get("args", None), | ||||||
|  |                     kwargs=task.get("kwargs", None), | ||||||
|  |                     **task.get("options", {}), | ||||||
|  |                 ) | ||||||
|  | |||||||
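The past-due check above hinges on crontab.is_due(), which reports whether a run is overdue and how long until the next check. A short sketch of that call with an assumed last-finish timestamp:

    from datetime import datetime, timedelta, timezone

    from celery.schedules import crontab

    schedule = crontab(minute=30, hour="*/8")
    last_finish = datetime.now(timezone.utc) - timedelta(days=1)  # assumed last run
    due, next_check_in = schedule.is_due(last_finish)
    if due:
        print(f"behind schedule; next check in roughly {next_check_in:.0f}s")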
| @ -1,22 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-24 15:36 |  | ||||||
|  |  | ||||||
| from django.db import migrations |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AlterModelOptions( |  | ||||||
|             name="systemtask", |  | ||||||
|             options={ |  | ||||||
|                 "default_permissions": (), |  | ||||||
|                 "permissions": (), |  | ||||||
|                 "verbose_name": "System Task", |  | ||||||
|                 "verbose_name_plural": "System Tasks", |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -5,11 +5,12 @@ from datetime import timedelta | |||||||
| from difflib import get_close_matches | from difflib import get_close_matches | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from inspect import currentframe | from inspect import currentframe | ||||||
|  | from smtplib import SMTPException | ||||||
| from typing import Any | from typing import Any | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.db import models | from django.db import connection, models | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django.http.request import QueryDict | from django.http.request import QueryDict | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| @ -26,6 +27,7 @@ from authentik.core.middleware import ( | |||||||
|     SESSION_KEY_IMPERSONATE_USER, |     SESSION_KEY_IMPERSONATE_USER, | ||||||
| ) | ) | ||||||
| from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | ||||||
|  | from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME | ||||||
| from authentik.events.context_processors.base import get_context_processors | from authentik.events.context_processors.base import get_context_processors | ||||||
| from authentik.events.utils import ( | from authentik.events.utils import ( | ||||||
|     cleanse_dict, |     cleanse_dict, | ||||||
| @ -41,7 +43,6 @@ from authentik.lib.utils.time import timedelta_from_string | |||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
| from authentik.tasks.models import TasksModel |  | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
| @ -266,8 +267,7 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|             models.Index(fields=["created"]), |             models.Index(fields=["created"]), | ||||||
|             models.Index(fields=["client_ip"]), |             models.Index(fields=["client_ip"]), | ||||||
|             models.Index( |             models.Index( | ||||||
|                 models.F("context__authorized_application"), |                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" | ||||||
|                 name="authentik_e_ctx_app__idx", |  | ||||||
|             ), |             ), | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
| @ -281,7 +281,7 @@ class TransportMode(models.TextChoices): | |||||||
|     EMAIL = "email", _("Email") |     EMAIL = "email", _("Email") | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotificationTransport(TasksModel, SerializerModel): | class NotificationTransport(SerializerModel): | ||||||
|     """Action which is executed when a Rule matches""" |     """Action which is executed when a Rule matches""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) |     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
| @ -446,8 +446,6 @@ class NotificationTransport(TasksModel, SerializerModel): | |||||||
|  |  | ||||||
|     def send_email(self, notification: "Notification") -> list[str]: |     def send_email(self, notification: "Notification") -> list[str]: | ||||||
|         """Send notification via global email configuration""" |         """Send notification via global email configuration""" | ||||||
|         from authentik.stages.email.tasks import send_mail |  | ||||||
|  |  | ||||||
|         if notification.user.email.strip() == "": |         if notification.user.email.strip() == "": | ||||||
|             LOGGER.info( |             LOGGER.info( | ||||||
|                 "Discarding notification as user has no email address", |                 "Discarding notification as user has no email address", | ||||||
| @ -489,14 +487,17 @@ class NotificationTransport(TasksModel, SerializerModel): | |||||||
|             template_name="email/event_notification.html", |             template_name="email/event_notification.html", | ||||||
|             template_context=context, |             template_context=context, | ||||||
|         ) |         ) | ||||||
|         send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self) |         # Email is sent directly here, as the call to send() should have been from a task. | ||||||
|         return [] |         try: | ||||||
|  |             from authentik.stages.email.tasks import send_mail | ||||||
|  |  | ||||||
|  |             return send_mail(mail.__dict__) | ||||||
|  |         except (SMTPException, ConnectionError, OSError) as exc: | ||||||
|  |             raise NotificationTransportError(exc) from exc | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.events.api.notification_transports import ( |         from authentik.events.api.notification_transports import NotificationTransportSerializer | ||||||
|             NotificationTransportSerializer, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return NotificationTransportSerializer |         return NotificationTransportSerializer | ||||||
|  |  | ||||||
| @ -546,7 +547,7 @@ class Notification(SerializerModel): | |||||||
|         verbose_name_plural = _("Notifications") |         verbose_name_plural = _("Notifications") | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel): | class NotificationRule(SerializerModel, PolicyBindingModel): | ||||||
|     """Decide when to create a Notification based on policies attached to this object.""" |     """Decide when to create a Notification based on policies attached to this object.""" | ||||||
|  |  | ||||||
|     name = models.TextField(unique=True) |     name = models.TextField(unique=True) | ||||||
| @ -610,9 +611,7 @@ class NotificationWebhookMapping(PropertyMapping): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[type[Serializer]]: |     def serializer(self) -> type[type[Serializer]]: | ||||||
|         from authentik.events.api.notification_mappings import ( |         from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer | ||||||
|             NotificationWebhookMappingSerializer, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return NotificationWebhookMappingSerializer |         return NotificationWebhookMappingSerializer | ||||||
|  |  | ||||||
| @ -625,7 +624,7 @@ class NotificationWebhookMapping(PropertyMapping): | |||||||
|  |  | ||||||
|  |  | ||||||
| class TaskStatus(models.TextChoices): | class TaskStatus(models.TextChoices): | ||||||
|     """DEPRECATED do not use""" |     """Possible states of tasks""" | ||||||
|  |  | ||||||
|     UNKNOWN = "unknown" |     UNKNOWN = "unknown" | ||||||
|     SUCCESSFUL = "successful" |     SUCCESSFUL = "successful" | ||||||
| @ -633,8 +632,8 @@ class TaskStatus(models.TextChoices): | |||||||
|     ERROR = "error" |     ERROR = "error" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SystemTask(ExpiringModel): | class SystemTask(SerializerModel, ExpiringModel): | ||||||
|     """DEPRECATED do not use""" |     """Info about a system task running in the background along with details to restart the task""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) |     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
|     name = models.TextField() |     name = models.TextField() | ||||||
| @ -654,13 +653,41 @@ class SystemTask(ExpiringModel): | |||||||
|     task_call_args = models.JSONField(default=list) |     task_call_args = models.JSONField(default=list) | ||||||
|     task_call_kwargs = models.JSONField(default=dict) |     task_call_kwargs = models.JSONField(default=dict) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def serializer(self) -> type[Serializer]: | ||||||
|  |         from authentik.events.api.tasks import SystemTaskSerializer | ||||||
|  |  | ||||||
|  |         return SystemTaskSerializer | ||||||
|  |  | ||||||
|  |     def update_metrics(self): | ||||||
|  |         """Update prometheus metrics""" | ||||||
|  |         # TODO: Deprecated metric - remove in 2024.2 or later | ||||||
|  |         GAUGE_TASKS.labels( | ||||||
|  |             tenant=connection.schema_name, | ||||||
|  |             task_name=self.name, | ||||||
|  |             task_uid=self.uid or "", | ||||||
|  |             status=self.status.lower(), | ||||||
|  |         ).set(self.duration) | ||||||
|  |         SYSTEM_TASK_TIME.labels( | ||||||
|  |             tenant=connection.schema_name, | ||||||
|  |             task_name=self.name, | ||||||
|  |             task_uid=self.uid or "", | ||||||
|  |         ).observe(self.duration) | ||||||
|  |         SYSTEM_TASK_STATUS.labels( | ||||||
|  |             tenant=connection.schema_name, | ||||||
|  |             task_name=self.name, | ||||||
|  |             task_uid=self.uid or "", | ||||||
|  |             status=self.status.lower(), | ||||||
|  |         ).inc() | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"System Task {self.name}" |         return f"System Task {self.name}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("name", "uid"),) |         unique_together = (("name", "uid"),) | ||||||
|         default_permissions = () |         # Remove "add", "change" and "delete" permissions as those are not used | ||||||
|         permissions = () |         default_permissions = ["view"] | ||||||
|  |         permissions = [("run_task", _("Run task"))] | ||||||
|         verbose_name = _("System Task") |         verbose_name = _("System Task") | ||||||
|         verbose_name_plural = _("System Tasks") |         verbose_name_plural = _("System Tasks") | ||||||
|         indexes = ExpiringModel.Meta.indexes |         indexes = ExpiringModel.Meta.indexes | ||||||
|  | |||||||
13  authentik/events/settings.py  Normal file
| @ -0,0 +1,13 @@ | |||||||
|  | """Event Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "events_notification_cleanup": { | ||||||
|  |         "task": "authentik.events.tasks.notification_cleanup", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
| @ -12,10 +12,13 @@ from rest_framework.request import Request | |||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.core.signals import login_failed, password_changed | from authentik.core.signals import login_failed, password_changed | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.apps import SYSTEM_TASK_STATUS | ||||||
|  | from authentik.events.models import Event, EventAction, SystemTask | ||||||
|  | from authentik.events.tasks import event_notification_handler, gdpr_cleanup | ||||||
| from authentik.flows.models import Stage | from authentik.flows.models import Stage | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan | from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
|  | from authentik.root.monitoring import monitoring_set | ||||||
| from authentik.stages.invitation.models import Invitation | from authentik.stages.invitation.models import Invitation | ||||||
| from authentik.stages.invitation.signals import invitation_used | from authentik.stages.invitation.signals import invitation_used | ||||||
| from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | ||||||
| @ -111,15 +114,19 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest | |||||||
| @receiver(post_save, sender=Event) | @receiver(post_save, sender=Event) | ||||||
| def event_post_save_notification(sender, instance: Event, **_): | def event_post_save_notification(sender, instance: Event, **_): | ||||||
|     """Start task to check if any policies trigger a notification on this event""" |     """Start task to check if any policies trigger a notification on this event""" | ||||||
|     from authentik.events.tasks import event_trigger_dispatch |     event_notification_handler.delay(instance.event_uuid.hex) | ||||||
|  |  | ||||||
|     event_trigger_dispatch.send(instance.event_uuid) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=User) | @receiver(pre_delete, sender=User) | ||||||
| def event_user_pre_delete_cleanup(sender, instance: User, **_): | def event_user_pre_delete_cleanup(sender, instance: User, **_): | ||||||
|     """If gdpr_compliance is enabled, remove all the user's events""" |     """If gdpr_compliance is enabled, remove all the user's events""" | ||||||
|     from authentik.events.tasks import gdpr_cleanup |  | ||||||
|  |  | ||||||
|     if get_current_tenant().gdpr_compliance: |     if get_current_tenant().gdpr_compliance: | ||||||
|         gdpr_cleanup.send(instance.pk) |         gdpr_cleanup.delay(instance.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(monitoring_set) | ||||||
|  | def monitoring_system_task(sender, **_): | ||||||
|  |     """Update metrics when task is saved""" | ||||||
|  |     SYSTEM_TASK_STATUS.clear() | ||||||
|  |     for task in SystemTask.objects.all(): | ||||||
|  |         task.update_metrics() | ||||||
|  | |||||||
156  authentik/events/system_tasks.py  Normal file
| @ -0,0 +1,156 @@ | |||||||
|  | """Monitored tasks""" | ||||||
|  |  | ||||||
|  | from datetime import datetime, timedelta | ||||||
|  | from time import perf_counter | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
|  | from django.utils.timezone import now | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  | from tenant_schemas_celery.task import TenantTask | ||||||
|  |  | ||||||
|  | from authentik.events.logs import LogEvent | ||||||
|  | from authentik.events.models import Event, EventAction, TaskStatus | ||||||
|  | from authentik.events.models import SystemTask as DBSystemTask | ||||||
|  | from authentik.events.utils import sanitize_item | ||||||
|  | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SystemTask(TenantTask): | ||||||
|  |     """Task which can save its state to the cache""" | ||||||
|  |  | ||||||
|  |     logger: BoundLogger | ||||||
|  |  | ||||||
|  |     # For tasks that should only be listed if they failed, set this to False | ||||||
|  |     save_on_success: bool | ||||||
|  |  | ||||||
|  |     _status: TaskStatus | ||||||
|  |     _messages: list[LogEvent] | ||||||
|  |  | ||||||
|  |     _uid: str | None | ||||||
|  |     # Precise start time from perf_counter | ||||||
|  |     _start_precise: float | None = None | ||||||
|  |     _start: datetime | None = None | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, **kwargs) -> None: | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self._status = TaskStatus.SUCCESSFUL | ||||||
|  |         self.save_on_success = True | ||||||
|  |         self._uid = None | ||||||
|  |         self._status = None | ||||||
|  |         self._messages = [] | ||||||
|  |         self.result_timeout_hours = 6 | ||||||
|  |  | ||||||
|  |     def set_uid(self, uid: str): | ||||||
|  |         """Set UID, so in the case of an unexpected error its saved correctly""" | ||||||
|  |         self._uid = uid | ||||||
|  |  | ||||||
|  |     def set_status(self, status: TaskStatus, *messages: LogEvent): | ||||||
|  |         """Set result for current run, will overwrite previous result.""" | ||||||
|  |         self._status = status | ||||||
|  |         self._messages = list(messages) | ||||||
|  |         for idx, msg in enumerate(self._messages): | ||||||
|  |             if not isinstance(msg, LogEvent): | ||||||
|  |                 self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") | ||||||
|  |  | ||||||
|  |     def set_error(self, exception: Exception, *messages: LogEvent): | ||||||
|  |         """Set result to error and save exception""" | ||||||
|  |         self._status = TaskStatus.ERROR | ||||||
|  |         self._messages = list(messages) | ||||||
|  |         self._messages.extend( | ||||||
|  |             [LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def before_start(self, task_id, args, kwargs): | ||||||
|  |         self._start_precise = perf_counter() | ||||||
|  |         self._start = now() | ||||||
|  |         self.logger = get_logger().bind(task_id=task_id) | ||||||
|  |         return super().before_start(task_id, args, kwargs) | ||||||
|  |  | ||||||
|  |     def db(self) -> DBSystemTask | None: | ||||||
|  |         """Get DB object for latest task""" | ||||||
|  |         return DBSystemTask.objects.filter( | ||||||
|  |             name=self.__name__, | ||||||
|  |             uid=self._uid, | ||||||
|  |         ).first() | ||||||
|  |  | ||||||
|  |     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||||
|  |         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||||
|  |         if not self._status: | ||||||
|  |             return | ||||||
|  |         if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success: | ||||||
|  |             DBSystemTask.objects.filter( | ||||||
|  |                 name=self.__name__, | ||||||
|  |                 uid=self._uid, | ||||||
|  |             ).delete() | ||||||
|  |             return | ||||||
|  |         DBSystemTask.objects.update_or_create( | ||||||
|  |             name=self.__name__, | ||||||
|  |             uid=self._uid, | ||||||
|  |             defaults={ | ||||||
|  |                 "description": self.__doc__, | ||||||
|  |                 "start_timestamp": self._start or now(), | ||||||
|  |                 "finish_timestamp": now(), | ||||||
|  |                 "duration": max(perf_counter() - self._start_precise, 0), | ||||||
|  |                 "task_call_module": self.__module__, | ||||||
|  |                 "task_call_func": self.__name__, | ||||||
|  |                 "task_call_args": sanitize_item(args), | ||||||
|  |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|  |                 "status": self._status, | ||||||
|  |                 "messages": sanitize_item(self._messages), | ||||||
|  |                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||||
|  |                 "expiring": True, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||||
|  |         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||||
|  |         if not self._status: | ||||||
|  |             self.set_error(exc) | ||||||
|  |         DBSystemTask.objects.update_or_create( | ||||||
|  |             name=self.__name__, | ||||||
|  |             uid=self._uid, | ||||||
|  |             defaults={ | ||||||
|  |                 "description": self.__doc__, | ||||||
|  |                 "start_timestamp": self._start or now(), | ||||||
|  |                 "finish_timestamp": now(), | ||||||
|  |                 "duration": max(perf_counter() - self._start_precise, 0), | ||||||
|  |                 "task_call_module": self.__module__, | ||||||
|  |                 "task_call_func": self.__name__, | ||||||
|  |                 "task_call_args": sanitize_item(args), | ||||||
|  |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|  |                 "status": self._status, | ||||||
|  |                 "messages": sanitize_item(self._messages), | ||||||
|  |                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), | ||||||
|  |                 "expiring": True, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         Event.new( | ||||||
|  |             EventAction.SYSTEM_TASK_EXCEPTION, | ||||||
|  |             message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}", | ||||||
|  |         ).save() | ||||||
|  |  | ||||||
|  |     def run(self, *args, **kwargs): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def prefill_task(func): | ||||||
|  |     """Ensure a task's details are always in cache, so it can always be triggered via API""" | ||||||
|  |     _prefill_tasks.append( | ||||||
|  |         DBSystemTask( | ||||||
|  |             name=func.__name__, | ||||||
|  |             description=func.__doc__, | ||||||
|  |             start_timestamp=now(), | ||||||
|  |             finish_timestamp=now(), | ||||||
|  |             status=TaskStatus.UNKNOWN, | ||||||
|  |             messages=sanitize_item([_("Task has not been run yet.")]), | ||||||
|  |             task_call_module=func.__module__, | ||||||
|  |             task_call_func=func.__name__, | ||||||
|  |             expiring=False, | ||||||
|  |             duration=0, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |     return func | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _prefill_tasks = [] | ||||||
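The SystemTask base class above persists every run from after_return() and on_failure(). A stripped-down, self-contained sketch of that Celery pattern; the class name and in-memory storage are illustrative, not authentik's:

    from celery import Celery, Task

    app = Celery("sketch", broker="memory://")
    app.conf.task_always_eager = True  # run inline so no worker is needed
    RESULTS: dict[str, str] = {}

    class MonitoredTask(Task):
        """Record an outcome per task name once a run finishes."""

        _status: str = "unknown"

        def set_status(self, status: str) -> None:
            self._status = status

        def after_return(self, status, retval, task_id, args, kwargs, einfo):
            # The real code writes to the database here; a dict keeps the sketch small.
            RESULTS[self.name] = self._status

    @app.task(bind=True, base=MonitoredTask)
    def cleanup(self: MonitoredTask):
        """Toy task that always reports success."""
        self.set_status("successful")

    cleanup.delay()
    print(RESULTS)  # maps the task name to its recorded status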
| @ -1,49 +1,41 @@ | |||||||
| """Event notification tasks""" | """Event notification tasks""" | ||||||
|  |  | ||||||
| from uuid import UUID |  | ||||||
|  |  | ||||||
| from django.db.models.query_utils import Q | from django.db.models.query_utils import Q | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from dramatiq.actor import actor |  | ||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.expression.exceptions import PropertyMappingExpressionException | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
|     Event, |     Event, | ||||||
|     Notification, |     Notification, | ||||||
|     NotificationRule, |     NotificationRule, | ||||||
|     NotificationTransport, |     NotificationTransport, | ||||||
|  |     NotificationTransportError, | ||||||
|  |     TaskStatus, | ||||||
| ) | ) | ||||||
|  | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.models import PolicyBinding, PolicyEngineMode | from authentik.policies.models import PolicyBinding, PolicyEngineMode | ||||||
| from authentik.tasks.models import Task | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Dispatch new event notifications.")) | @CELERY_APP.task() | ||||||
| def event_trigger_dispatch(event_uuid: UUID): | def event_notification_handler(event_uuid: str): | ||||||
|  |     """Start task for each trigger definition""" | ||||||
|     for trigger in NotificationRule.objects.all(): |     for trigger in NotificationRule.objects.all(): | ||||||
|         event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger) |         event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor( | @CELERY_APP.task() | ||||||
|     description=_( | def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||||
|         "Check if policies attached to NotificationRule match event " |  | ||||||
|         "and dispatch notification tasks." |  | ||||||
|     ) |  | ||||||
| ) |  | ||||||
| def event_trigger_handler(event_uuid: UUID, trigger_name: str): |  | ||||||
|     """Check if policies attached to NotificationRule match event""" |     """Check if policies attached to NotificationRule match event""" | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|  |  | ||||||
|     event: Event = Event.objects.filter(event_uuid=event_uuid).first() |     event: Event = Event.objects.filter(event_uuid=event_uuid).first() | ||||||
|     if not event: |     if not event: | ||||||
|         self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) |         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() |     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() | ||||||
|     if not trigger: |     if not trigger: | ||||||
|         return |         return | ||||||
| @ -78,27 +70,34 @@ def event_trigger_handler(event_uuid: UUID, trigger_name: str): | |||||||
|  |  | ||||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) |     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||||
|     # Create the notification objects |     # Create the notification objects | ||||||
|     count = 0 |  | ||||||
|     for transport in trigger.transports.all(): |     for transport in trigger.transports.all(): | ||||||
|         for user in trigger.destination_users(event): |         for user in trigger.destination_users(event): | ||||||
|             notification_transport.send_with_options( |             LOGGER.debug("created notification") | ||||||
|                 args=( |             notification_transport.apply_async( | ||||||
|  |                 args=[ | ||||||
|                     transport.pk, |                     transport.pk, | ||||||
|                     event.pk, |                     str(event.pk), | ||||||
|                     user.pk, |                     user.pk, | ||||||
|                     trigger.pk, |                     str(trigger.pk), | ||||||
|                 ), |                 ], | ||||||
|                 rel_obj=transport, |                 queue="authentik_events", | ||||||
|             ) |             ) | ||||||
|             count += 1 |  | ||||||
|             if transport.send_once: |             if transport.send_once: | ||||||
|                 break |                 break | ||||||
|     self.info(f"Created {count} notification tasks") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Send notification.")) | @CELERY_APP.task( | ||||||
| def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): |     bind=True, | ||||||
|  |     autoretry_for=(NotificationTransportError,), | ||||||
|  |     retry_backoff=True, | ||||||
|  |     base=SystemTask, | ||||||
|  | ) | ||||||
|  | def notification_transport( | ||||||
|  |     self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str | ||||||
|  | ): | ||||||
|     """Send notification over specified transport""" |     """Send notification over specified transport""" | ||||||
|  |     self.save_on_success = False | ||||||
|  |     try: | ||||||
|         event = Event.objects.filter(pk=event_pk).first() |         event = Event.objects.filter(pk=event_pk).first() | ||||||
|         if not event: |         if not event: | ||||||
|             return |             return | ||||||
| @ -111,13 +110,17 @@ def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigg | |||||||
|         notification = Notification( |         notification = Notification( | ||||||
|             severity=trigger.severity, body=event.summary, event=event, user=user |             severity=trigger.severity, body=event.summary, event=event, user=user | ||||||
|         ) |         ) | ||||||
|     transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first() |         transport = NotificationTransport.objects.filter(pk=transport_pk).first() | ||||||
|         if not transport: |         if not transport: | ||||||
|             return |             return | ||||||
|         transport.send(notification) |         transport.send(notification) | ||||||
|  |         self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|  |     except (NotificationTransportError, PropertyMappingExpressionException) as exc: | ||||||
|  |         self.set_error(exc) | ||||||
|  |         raise exc | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Cleanup events for GDPR compliance.")) | @CELERY_APP.task() | ||||||
| def gdpr_cleanup(user_pk: int): | def gdpr_cleanup(user_pk: int): | ||||||
|     """cleanup events from gdpr_compliance""" |     """cleanup events from gdpr_compliance""" | ||||||
|     events = Event.objects.filter(user__pk=user_pk) |     events = Event.objects.filter(user__pk=user_pk) | ||||||
| @ -125,12 +128,12 @@ def gdpr_cleanup(user_pk: int): | |||||||
|     events.delete() |     events.delete() | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Cleanup seen notifications and notifications whose event expired.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def notification_cleanup(): | @prefill_task | ||||||
|  | def notification_cleanup(self: SystemTask): | ||||||
|     """Cleanup seen notifications and notifications whose event expired.""" |     """Cleanup seen notifications and notifications whose event expired.""" | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) |     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) | ||||||
|     amount = notifications.count() |     amount = notifications.count() | ||||||
|     notifications.delete() |     notifications.delete() | ||||||
|     LOGGER.debug("Expired notifications", amount=amount) |     LOGGER.debug("Expired notifications", amount=amount) | ||||||
|     self.info(f"Expired {amount} Notifications") |     self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications") | ||||||
|  | |||||||
103  authentik/events/tests/test_tasks.py  Normal file
| @ -0,0 +1,103 @@ | |||||||
|  | """Test Monitored tasks""" | ||||||
|  |  | ||||||
|  | from json import loads | ||||||
|  |  | ||||||
|  | from django.urls import reverse | ||||||
|  | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
|  | from authentik.core.tasks import clean_expired_models | ||||||
|  | from authentik.core.tests.utils import create_test_admin_user | ||||||
|  | from authentik.events.models import SystemTask as DBSystemTask | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestSystemTasks(APITestCase): | ||||||
|  |     """Test Monitored tasks""" | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         super().setUp() | ||||||
|  |         self.user = create_test_admin_user() | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |  | ||||||
|  |     def test_failed_successful_remove_state(self): | ||||||
|  |         """Test that a task with `save_on_success` set to `False` saves its state | ||||||
|  |         when it fails, and removes that state after a later successful run""" | ||||||
|  |         should_fail = True | ||||||
|  |         uid = generate_id() | ||||||
|  |  | ||||||
|  |         @CELERY_APP.task( | ||||||
|  |             bind=True, | ||||||
|  |             base=SystemTask, | ||||||
|  |         ) | ||||||
|  |         def test_task(self: SystemTask): | ||||||
|  |             self.save_on_success = False | ||||||
|  |             self.set_uid(uid) | ||||||
|  |             self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL) | ||||||
|  |  | ||||||
|  |         # First test successful run | ||||||
|  |         should_fail = False | ||||||
|  |         test_task.delay().get() | ||||||
|  |         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) | ||||||
|  |  | ||||||
|  |         # Then test failed | ||||||
|  |         should_fail = True | ||||||
|  |         test_task.delay().get() | ||||||
|  |         task = DBSystemTask.objects.filter(name="test_task", uid=uid).first() | ||||||
|  |         self.assertEqual(task.status, TaskStatus.ERROR) | ||||||
|  |  | ||||||
|  |         # Then after that, the state should be removed | ||||||
|  |         should_fail = False | ||||||
|  |         test_task.delay().get() | ||||||
|  |         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) | ||||||
|  |  | ||||||
|  |     def test_tasks(self): | ||||||
|  |         """Test Task API""" | ||||||
|  |         clean_expired_models.delay().get() | ||||||
|  |         response = self.client.get(reverse("authentik_api:systemtask-list")) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         body = loads(response.content) | ||||||
|  |         self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"])) | ||||||
|  |  | ||||||
|  |     def test_tasks_single(self): | ||||||
|  |         """Test Task API (read single)""" | ||||||
|  |         clean_expired_models.delay().get() | ||||||
|  |         task = DBSystemTask.objects.filter(name="clean_expired_models").first() | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:systemtask-detail", | ||||||
|  |                 kwargs={"pk": str(task.pk)}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         body = loads(response.content) | ||||||
|  |         self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value) | ||||||
|  |         self.assertEqual(body["name"], "clean_expired_models") | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 404) | ||||||
|  |  | ||||||
|  |     def test_tasks_run(self): | ||||||
|  |         """Test Task API (run)""" | ||||||
|  |         clean_expired_models.delay().get() | ||||||
|  |         task = DBSystemTask.objects.filter(name="clean_expired_models").first() | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:systemtask-run", | ||||||
|  |                 kwargs={"pk": str(task.pk)}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 204) | ||||||
|  |  | ||||||
|  |     def test_tasks_run_404(self): | ||||||
|  |         """Test Task API (run, 404)""" | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:systemtask-run", | ||||||
|  |                 kwargs={"pk": "qwerqewrqrqewrqewr"}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 404) | ||||||
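These tests resolve task results inline with `.delay().get()`, which only returns synchronously when Celery runs tasks eagerly in the test environment. A minimal sketch of the standard Celery options that enable this; whether authentik's test settings use exactly these Django-namespaced names is an assumption:

    # Hypothetical test-settings fragment (namespace="CELERY"): run tasks
    # in-process so .delay().get() returns immediately without a worker.
    CELERY_TASK_ALWAYS_EAGER = True      # execute tasks synchronously in the caller
    CELERY_TASK_EAGER_PROPAGATES = True  # re-raise task exceptions instead of storing them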
| @ -5,11 +5,13 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin | |||||||
| from authentik.events.api.notification_rules import NotificationRuleViewSet | from authentik.events.api.notification_rules import NotificationRuleViewSet | ||||||
| from authentik.events.api.notification_transports import NotificationTransportViewSet | from authentik.events.api.notification_transports import NotificationTransportViewSet | ||||||
| from authentik.events.api.notifications import NotificationViewSet | from authentik.events.api.notifications import NotificationViewSet | ||||||
|  | from authentik.events.api.tasks import SystemTaskViewSet | ||||||
|  |  | ||||||
| api_urlpatterns = [ | api_urlpatterns = [ | ||||||
|     ("events/events", EventViewSet), |     ("events/events", EventViewSet), | ||||||
|     ("events/notifications", NotificationViewSet), |     ("events/notifications", NotificationViewSet), | ||||||
|     ("events/transports", NotificationTransportViewSet), |     ("events/transports", NotificationTransportViewSet), | ||||||
|     ("events/rules", NotificationRuleViewSet), |     ("events/rules", NotificationRuleViewSet), | ||||||
|  |     ("events/system_tasks", SystemTaskViewSet), | ||||||
|     ("propertymappings/notification", NotificationWebhookMappingViewSet), |     ("propertymappings/notification", NotificationWebhookMappingViewSet), | ||||||
| ] | ] | ||||||
|  | |||||||
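With `events/system_tasks` registered above, system tasks become listable over the REST API. An illustrative client call; the `/api/v3/` prefix, host, and token are assumptions about a concrete deployment, while the `results`, `name`, and `status` fields match the test assertions earlier in this diff:

    import requests

    AUTHENTIK_URL = "https://authentik.example.com"  # placeholder host
    TOKEN = "..."                                     # placeholder API token

    resp = requests.get(
        f"{AUTHENTIK_URL}/api/v3/events/system_tasks/",
        headers={"Authorization": f"Bearer {TOKEN}"},
        timeout=10,
    )
    resp.raise_for_status()
    for task in resp.json()["results"]:
        print(task["name"], task["status"])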
| @ -41,7 +41,6 @@ REDIS_ENV_KEYS = [ | |||||||
| # Old key -> new key | # Old key -> new key | ||||||
| DEPRECATIONS = { | DEPRECATIONS = { | ||||||
|     "geoip": "events.context_processors.geoip", |     "geoip": "events.context_processors.geoip", | ||||||
|     "worker.concurrency": "worker.processes", |  | ||||||
|     "redis.broker_url": "broker.url", |     "redis.broker_url": "broker.url", | ||||||
|     "redis.broker_transport_options": "broker.transport_options", |     "redis.broker_transport_options": "broker.transport_options", | ||||||
|     "redis.cache_timeout": "cache.timeout", |     "redis.cache_timeout": "cache.timeout", | ||||||
|  | |||||||
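The `DEPRECATIONS` map translates old configuration keys to their replacements; dropping `worker.concurrency` here means it becomes a live key again rather than an alias for `worker.processes`. A stand-alone sketch of how such a map can be applied when reading config, illustrative rather than authentik's actual resolution code:

    DEPRECATIONS = {
        "geoip": "events.context_processors.geoip",
        "redis.broker_url": "broker.url",
    }

    def resolve_key(key: str) -> str:
        """Return the current name for a possibly-deprecated config key."""
        if key in DEPRECATIONS:
            new_key = DEPRECATIONS[key]
            print(f"config key {key!r} is deprecated, use {new_key!r} instead")
            return new_key
        return key

    assert resolve_key("redis.broker_url") == "broker.url"
    assert resolve_key("broker.url") == "broker.url"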
| @ -21,10 +21,6 @@ def start_debug_server(**kwargs) -> bool: | |||||||
|  |  | ||||||
|     listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") |     listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") | ||||||
|     host, _, port = listen.rpartition(":") |     host, _, port = listen.rpartition(":") | ||||||
|     try: |  | ||||||
|     debugpy.listen((host, int(port)), **kwargs)  # nosec |     debugpy.listen((host, int(port)), **kwargs)  # nosec | ||||||
|     except RuntimeError: |  | ||||||
|         LOGGER.warning("Could not start debug server. Continuing without") |  | ||||||
|         return False |  | ||||||
|     LOGGER.debug("Starting debug server", host=host, port=port) |     LOGGER.debug("Starting debug server", host=host, port=port) | ||||||
|     return True |     return True | ||||||
|  | |||||||
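The removed branch on the left wrapped `debugpy.listen()` in a `try`/`except RuntimeError` so startup continues when a debug adapter is already listening; the new code lets that error propagate. A minimal sketch of the removed guard pattern, assuming `debugpy` is installed:

    import debugpy

    def start_debug_server_guarded(host: str = "127.0.0.1", port: int = 9901) -> bool:
        """Attempt to start a debugpy listener; return False if one is already running."""
        try:
            debugpy.listen((host, port))  # nosec
        except RuntimeError:
            # Already listening (or otherwise unable to bind): continue without debugging.
            return False
        return True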
| @ -157,14 +157,7 @@ web: | |||||||
|   path: / |   path: / | ||||||
|  |  | ||||||
| worker: | worker: | ||||||
|   processes: 2 |   concurrency: 2 | ||||||
|   threads: 1 |  | ||||||
|   consumer_listen_timeout: "seconds=30" |  | ||||||
|   task_max_retries: 20 |  | ||||||
|   task_default_time_limit: "minutes=10" |  | ||||||
|   task_purge_interval: "days=1" |  | ||||||
|   task_expiration: "days=30" |  | ||||||
|   scheduler_interval: "seconds=60" |  | ||||||
|  |  | ||||||
| storage: | storage: | ||||||
|   media: |   media: | ||||||
|  | |||||||
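The worker block collapses to a single `concurrency` setting. A sketch of reading it the same way the debug-server hunk above reads `listen.listen_debug_py`; the import path and call site are assumptions:

    from authentik.lib.config import CONFIG  # import path is an assumption

    # Default of 2 mirrors the packaged default above.
    concurrency = int(CONFIG.get("worker.concurrency", 2))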
| @ -88,6 +88,7 @@ def get_logger_config(): | |||||||
|         "authentik": global_level, |         "authentik": global_level, | ||||||
|         "django": "WARNING", |         "django": "WARNING", | ||||||
|         "django.request": "ERROR", |         "django.request": "ERROR", | ||||||
|  |         "celery": "WARNING", | ||||||
|         "selenium": "WARNING", |         "selenium": "WARNING", | ||||||
|         "docker": "WARNING", |         "docker": "WARNING", | ||||||
|         "urllib3": "WARNING", |         "urllib3": "WARNING", | ||||||
|  | |||||||
| @ -3,6 +3,8 @@ | |||||||
| from asyncio.exceptions import CancelledError | from asyncio.exceptions import CancelledError | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
|  | from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError | ||||||
|  | from celery.exceptions import CeleryError | ||||||
| from channels_redis.core import ChannelFull | from channels_redis.core import ChannelFull | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | ||||||
| @ -20,6 +22,7 @@ from sentry_sdk import HttpTransport, get_current_scope | |||||||
| from sentry_sdk import init as sentry_sdk_init | from sentry_sdk import init as sentry_sdk_init | ||||||
| from sentry_sdk.api import set_tag | from sentry_sdk.api import set_tag | ||||||
| from sentry_sdk.integrations.argv import ArgvIntegration | from sentry_sdk.integrations.argv import ArgvIntegration | ||||||
|  | from sentry_sdk.integrations.celery import CeleryIntegration | ||||||
| from sentry_sdk.integrations.django import DjangoIntegration | from sentry_sdk.integrations.django import DjangoIntegration | ||||||
| from sentry_sdk.integrations.redis import RedisIntegration | from sentry_sdk.integrations.redis import RedisIntegration | ||||||
| from sentry_sdk.integrations.socket import SocketIntegration | from sentry_sdk.integrations.socket import SocketIntegration | ||||||
| @ -68,6 +71,10 @@ ignored_classes = ( | |||||||
|     LocalProtocolError, |     LocalProtocolError, | ||||||
|     # rest_framework error |     # rest_framework error | ||||||
|     APIException, |     APIException, | ||||||
|  |     # celery errors | ||||||
|  |     WorkerLostError, | ||||||
|  |     CeleryError, | ||||||
|  |     SoftTimeLimitExceeded, | ||||||
|     # custom baseclass |     # custom baseclass | ||||||
|     SentryIgnoredException, |     SentryIgnoredException, | ||||||
|     # ldap errors |     # ldap errors | ||||||
| @ -108,6 +115,7 @@ def sentry_init(**sentry_init_kwargs): | |||||||
|             ArgvIntegration(), |             ArgvIntegration(), | ||||||
|             StdlibIntegration(), |             StdlibIntegration(), | ||||||
|             DjangoIntegration(transaction_style="function_name", cache_spans=True), |             DjangoIntegration(transaction_style="function_name", cache_spans=True), | ||||||
|  |             CeleryIntegration(), | ||||||
|             RedisIntegration(), |             RedisIntegration(), | ||||||
|             ThreadingIntegration(propagate_hub=True), |             ThreadingIntegration(propagate_hub=True), | ||||||
|             SocketIntegration(), |             SocketIntegration(), | ||||||
| @ -152,11 +160,14 @@ def before_send(event: dict, hint: dict) -> dict | None: | |||||||
|             return None |             return None | ||||||
|     if "logger" in event: |     if "logger" in event: | ||||||
|         if event["logger"] in [ |         if event["logger"] in [ | ||||||
|  |             "kombu", | ||||||
|             "asyncio", |             "asyncio", | ||||||
|             "multiprocessing", |             "multiprocessing", | ||||||
|             "django_redis", |             "django_redis", | ||||||
|             "django.security.DisallowedHost", |             "django.security.DisallowedHost", | ||||||
|             "django_redis.cache", |             "django_redis.cache", | ||||||
|  |             "celery.backends.redis", | ||||||
|  |             "celery.worker", | ||||||
|             "paramiko.transport", |             "paramiko.transport", | ||||||
|         ]: |         ]: | ||||||
|             return None |             return None | ||||||
|  | |||||||
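The Sentry changes add `CeleryIntegration()`, ignore common celery worker exceptions, and silence noisy celery/kombu loggers in `before_send`. A minimal, stand-alone sketch of that logger-filtering pattern wired into `sentry_sdk.init()`; the DSN is a placeholder and the logger list is abbreviated from the diff:

    import sentry_sdk

    IGNORED_LOGGERS = {"kombu", "celery.backends.redis", "celery.worker"}

    def before_send(event: dict, hint: dict) -> dict | None:
        # Drop events emitted by celery/kombu internals; keep everything else.
        if event.get("logger") in IGNORED_LOGGERS:
            return None
        return event

    sentry_sdk.init(dsn="https://public@sentry.example.com/1", before_send=before_send)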
| @ -1,12 +0,0 @@ | |||||||
| from rest_framework.fields import BooleanField, ChoiceField, DateTimeField |  | ||||||
|  |  | ||||||
| from authentik.core.api.utils import PassiveSerializer |  | ||||||
| from authentik.tasks.models import TaskStatus |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SyncStatusSerializer(PassiveSerializer): |  | ||||||
|     """Provider/source sync status""" |  | ||||||
|  |  | ||||||
|     is_running = BooleanField() |  | ||||||
|     last_successful_sync = DateTimeField(required=False) |  | ||||||
|     last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices) |  | ||||||
| @ -1,7 +1,7 @@ | |||||||
| """Sync constants""" | """Sync constants""" | ||||||
|  |  | ||||||
| PAGE_SIZE = 100 | PAGE_SIZE = 100 | ||||||
| PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000  # Half an hour | PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour | ||||||
| HTTP_CONFLICT = 409 | HTTP_CONFLICT = 409 | ||||||
| HTTP_NO_CONTENT = 204 | HTTP_NO_CONTENT = 204 | ||||||
| HTTP_SERVICE_UNAVAILABLE = 503 | HTTP_SERVICE_UNAVAILABLE = 503 | ||||||
|  | |||||||
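`PAGE_TIMEOUT` is now expressed in seconds (60 * 60 * 0.5 = 1800, half an hour) instead of milliseconds. A worked example of the time-limit arithmetic used later in this diff, with illustrative page counts:

    PAGE_TIMEOUT = 60 * 60 * 0.5  # 1800 seconds per page of objects

    users_pages, groups_pages = 3, 2                               # illustrative counts
    soft_time_limit = (users_pages + groups_pages) * PAGE_TIMEOUT  # 9000 s
    time_limit = soft_time_limit * 1.5                             # 13500 s
    print(int(soft_time_limit), int(time_limit))                   # 9000 13500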
| @ -1,5 +1,7 @@ | |||||||
| from dramatiq.actor import Actor | from celery import Task | ||||||
| from drf_spectacular.utils import extend_schema | from django.utils.text import slugify | ||||||
|  | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
|  | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import BooleanField, CharField, ChoiceField | from rest_framework.fields import BooleanField, CharField, ChoiceField | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -7,12 +9,18 @@ from rest_framework.response import Response | |||||||
|  |  | ||||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.logs import LogEventSerializer | from authentik.events.api.tasks import SystemTaskSerializer | ||||||
| from authentik.lib.sync.api import SyncStatusSerializer | from authentik.events.logs import LogEvent, LogEventSerializer | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.rbac.filters import ObjectFilter | from authentik.rbac.filters import ObjectFilter | ||||||
| from authentik.tasks.models import Task, TaskStatus |  | ||||||
|  |  | ||||||
|  | class SyncStatusSerializer(PassiveSerializer): | ||||||
|  |     """Provider sync status""" | ||||||
|  |  | ||||||
|  |     is_running = BooleanField(read_only=True) | ||||||
|  |     tasks = SystemTaskSerializer(many=True, read_only=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SyncObjectSerializer(PassiveSerializer): | class SyncObjectSerializer(PassiveSerializer): | ||||||
| @ -37,10 +45,15 @@ class SyncObjectResultSerializer(PassiveSerializer): | |||||||
| class OutgoingSyncProviderStatusMixin: | class OutgoingSyncProviderStatusMixin: | ||||||
|     """Common API Endpoints for Outgoing sync providers""" |     """Common API Endpoints for Outgoing sync providers""" | ||||||
|  |  | ||||||
|     sync_task: Actor |     sync_single_task: type[Task] = None | ||||||
|     sync_objects_task: Actor |     sync_objects_task: type[Task] = None | ||||||
|  |  | ||||||
|     @extend_schema(responses={200: SyncStatusSerializer()}) |     @extend_schema( | ||||||
|  |         responses={ | ||||||
|  |             200: SyncStatusSerializer(), | ||||||
|  |             404: OpenApiResponse(description="Task not found"), | ||||||
|  |         } | ||||||
|  |     ) | ||||||
|     @action( |     @action( | ||||||
|         methods=["GET"], |         methods=["GET"], | ||||||
|         detail=True, |         detail=True, | ||||||
| @ -51,39 +64,18 @@ class OutgoingSyncProviderStatusMixin: | |||||||
|     def sync_status(self, request: Request, pk: int) -> Response: |     def sync_status(self, request: Request, pk: int) -> Response: | ||||||
|         """Get provider's sync status""" |         """Get provider's sync status""" | ||||||
|         provider: OutgoingSyncProvider = self.get_object() |         provider: OutgoingSyncProvider = self.get_object() | ||||||
|  |         tasks = list( | ||||||
|         status = {} |             get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( | ||||||
|  |                 name=self.sync_single_task.__name__, | ||||||
|  |                 uid=slugify(provider.name), | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|         with provider.sync_lock as lock_acquired: |         with provider.sync_lock as lock_acquired: | ||||||
|  |             status = { | ||||||
|  |                 "tasks": tasks, | ||||||
|                 # If we could not acquire the lock, it means a task is using it, and thus is running |                 # If we could not acquire the lock, it means a task is using it, and thus is running | ||||||
|             status["is_running"] = not lock_acquired |                 "is_running": not lock_acquired, | ||||||
|  |             } | ||||||
|         sync_schedule = None |  | ||||||
|         for schedule in provider.schedules.all(): |  | ||||||
|             if schedule.actor_name == self.sync_task.actor_name: |  | ||||||
|                 sync_schedule = schedule |  | ||||||
|  |  | ||||||
|         if not sync_schedule: |  | ||||||
|             return Response(SyncStatusSerializer(status).data) |  | ||||||
|  |  | ||||||
|         last_task: Task = ( |  | ||||||
|             sync_schedule.tasks.exclude( |  | ||||||
|                 aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) |  | ||||||
|             ) |  | ||||||
|             .order_by("-mtime") |  | ||||||
|             .first() |  | ||||||
|         ) |  | ||||||
|         last_successful_task: Task = ( |  | ||||||
|             sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) |  | ||||||
|             .order_by("-mtime") |  | ||||||
|             .first() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         if last_task: |  | ||||||
|             status["last_sync_status"] = last_task.aggregated_status |  | ||||||
|         if last_successful_task: |  | ||||||
|             status["last_successful_sync"] = last_successful_task.mtime |  | ||||||
|  |  | ||||||
|         return Response(SyncStatusSerializer(status).data) |         return Response(SyncStatusSerializer(status).data) | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
| @ -102,20 +94,14 @@ class OutgoingSyncProviderStatusMixin: | |||||||
|         provider: OutgoingSyncProvider = self.get_object() |         provider: OutgoingSyncProvider = self.get_object() | ||||||
|         params = SyncObjectSerializer(data=request.data) |         params = SyncObjectSerializer(data=request.data) | ||||||
|         params.is_valid(raise_exception=True) |         params.is_valid(raise_exception=True) | ||||||
|         msg = self.sync_objects_task.send_with_options( |         res: list[LogEvent] = self.sync_objects_task.delay( | ||||||
|             kwargs={ |             params.validated_data["sync_object_model"], | ||||||
|                 "object_type": params.validated_data["sync_object_model"], |             page=1, | ||||||
|                 "page": 1, |             provider_pk=provider.pk, | ||||||
|                 "provider_pk": provider.pk, |             pk=params.validated_data["sync_object_id"], | ||||||
|                 "override_dry_run": params.validated_data["override_dry_run"], |             override_dry_run=params.validated_data["override_dry_run"], | ||||||
|                 "pk": params.validated_data["sync_object_id"], |         ).get() | ||||||
|             }, |         return Response(SyncObjectResultSerializer(instance={"messages": res}).data) | ||||||
|             rel_obj=provider, |  | ||||||
|         ) |  | ||||||
|         msg.get_result(block=True) |  | ||||||
|         task: Task = msg.options["task"] |  | ||||||
|         task.refresh_from_db() |  | ||||||
|         return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncConnectionCreateMixin: | class OutgoingSyncConnectionCreateMixin: | ||||||
|  | |||||||
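The `sync_status` action now reports `is_running` (derived from the sync lock) plus the matching `SystemTask` rows. An illustrative client-side poll of that endpoint; the URL path, host, and token are assumptions about a concrete provider type and deployment:

    import time
    import requests

    def wait_for_sync(base_url: str, provider_pk: int, token: str) -> dict:
        """Poll the sync status action until is_running flips to False."""
        url = f"{base_url}/api/v3/providers/scim/{provider_pk}/sync/status/"  # path is an assumption
        headers = {"Authorization": f"Bearer {token}"}
        while True:
            status = requests.get(url, headers=headers, timeout=10).json()
            if not status["is_running"]:
                return status  # includes the "tasks" list from SystemTaskSerializer
            time.sleep(5)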
| @ -1,18 +1,12 @@ | |||||||
| from typing import Any, Self | from typing import Any, Self | ||||||
|  |  | ||||||
| import pglock | import pglock | ||||||
| from django.core.paginator import Paginator |  | ||||||
| from django.db import connection, models | from django.db import connection, models | ||||||
| from django.db.models import Model, QuerySet, TextChoices | from django.db.models import Model, QuerySet, TextChoices | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from dramatiq.actor import Actor |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS |  | ||||||
| from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient | from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
| from authentik.tasks.schedules.models import ScheduledModel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncDeleteAction(TextChoices): | class OutgoingSyncDeleteAction(TextChoices): | ||||||
| @ -24,7 +18,7 @@ class OutgoingSyncDeleteAction(TextChoices): | |||||||
|     SUSPEND = "suspend" |     SUSPEND = "suspend" | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncProvider(ScheduledModel, Model): | class OutgoingSyncProvider(Model): | ||||||
|     """Base abstract models for providers implementing outgoing sync""" |     """Base abstract models for providers implementing outgoing sync""" | ||||||
|  |  | ||||||
|     dry_run = models.BooleanField( |     dry_run = models.BooleanField( | ||||||
| @ -45,19 +39,6 @@ class OutgoingSyncProvider(ScheduledModel, Model): | |||||||
|     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: |     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def get_paginator[T: User | Group](self, type: type[T]) -> Paginator: |  | ||||||
|         return Paginator(self.get_object_qs(type), PAGE_SIZE) |  | ||||||
|  |  | ||||||
|     def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int: |  | ||||||
|         num_pages: int = self.get_paginator(type).num_pages |  | ||||||
|         return int(num_pages * PAGE_TIMEOUT_MS * 1.5) |  | ||||||
|  |  | ||||||
|     def get_sync_time_limit_ms(self) -> int: |  | ||||||
|         return int( |  | ||||||
|             (self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group)) |  | ||||||
|             * 1.5 |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def sync_lock(self) -> pglock.advisory: |     def sync_lock(self) -> pglock.advisory: | ||||||
|         """Postgres lock for syncing to prevent multiple parallel syncs happening""" |         """Postgres lock for syncing to prevent multiple parallel syncs happening""" | ||||||
| @ -66,22 +47,3 @@ class OutgoingSyncProvider(ScheduledModel, Model): | |||||||
|             timeout=0, |             timeout=0, | ||||||
|             side_effect=pglock.Return, |             side_effect=pglock.Return, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sync_actor(self) -> Actor: |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=self.sync_actor, |  | ||||||
|                 uid=self.pk, |  | ||||||
|                 args=(self.pk,), |  | ||||||
|                 options={ |  | ||||||
|                     "time_limit": self.get_sync_time_limit_ms(), |  | ||||||
|                 }, |  | ||||||
|                 send_on_save=True, |  | ||||||
|                 crontab=f"{fqdn_rand(self.pk)} */4 * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  | |||||||
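`sync_lock` keeps the pglock-based advisory lock with `timeout=0` and `side_effect=pglock.Return`, so entering the context manager yields whether the lock was acquired instead of blocking. A minimal sketch of that pattern, assuming a configured Django database and the `pglock` package; the lock name is a placeholder:

    import pglock

    lock = pglock.advisory("example-provider-sync", timeout=0, side_effect=pglock.Return)
    with lock as acquired:
        if not acquired:
            print("another sync already holds the lock, skipping")
        else:
            print("lock held for the duration of this block, safe to sync")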
| @ -1,8 +1,12 @@ | |||||||
|  | from collections.abc import Callable | ||||||
|  |  | ||||||
|  | from django.core.paginator import Paginator | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
|  | from django.db.models.query import Q | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete | from django.db.models.signals import m2m_changed, post_save, pre_delete | ||||||
| from dramatiq.actor import Actor |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
|  | from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT | ||||||
| from authentik.lib.sync.outgoing.base import Direction | from authentik.lib.sync.outgoing.base import Direction | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| @ -10,30 +14,45 @@ from authentik.lib.utils.reflection import class_to_path | |||||||
|  |  | ||||||
| def register_signals( | def register_signals( | ||||||
|     provider_type: type[OutgoingSyncProvider], |     provider_type: type[OutgoingSyncProvider], | ||||||
|     task_sync_direct_dispatch: Actor[[str, str | int, str], None], |     task_sync_single: Callable[[int], None], | ||||||
|     task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None], |     task_sync_direct: Callable[[int], None], | ||||||
|  |     task_sync_m2m: Callable[[int], None], | ||||||
| ): | ): | ||||||
|     """Register sync signals""" |     """Register sync signals""" | ||||||
|     uid = class_to_path(provider_type) |     uid = class_to_path(provider_type) | ||||||
|  |  | ||||||
|  |     def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_): | ||||||
|  |         """Trigger sync when Provider is saved""" | ||||||
|  |         users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE) | ||||||
|  |         groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE) | ||||||
|  |         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT | ||||||
|  |         time_limit = soft_time_limit * 1.5 | ||||||
|  |         task_sync_single.apply_async( | ||||||
|  |             (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False) | ||||||
|  |  | ||||||
|     def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): |     def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): | ||||||
|         """Post save handler""" |         """Post save handler""" | ||||||
|         task_sync_direct_dispatch.send( |         if not provider_type.objects.filter( | ||||||
|             class_to_path(instance.__class__), |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|             instance.pk, |         ).exists(): | ||||||
|             Direction.add.value, |             return | ||||||
|         ) |         task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value) | ||||||
|  |  | ||||||
|     post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) |     post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) | ||||||
|     post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) |     post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) | ||||||
|  |  | ||||||
|     def model_pre_delete(sender: type[Model], instance: User | Group, **_): |     def model_pre_delete(sender: type[Model], instance: User | Group, **_): | ||||||
|         """Pre-delete handler""" |         """Pre-delete handler""" | ||||||
|         task_sync_direct_dispatch.send( |         if not provider_type.objects.filter( | ||||||
|             class_to_path(instance.__class__), |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|             instance.pk, |         ).exists(): | ||||||
|             Direction.remove.value, |             return | ||||||
|         ) |         task_sync_direct.delay( | ||||||
|  |             class_to_path(instance.__class__), instance.pk, Direction.remove.value | ||||||
|  |         ).get(propagate=False) | ||||||
|  |  | ||||||
|     pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) |     pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) | ||||||
|     pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) |     pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) | ||||||
| @ -44,6 +63,16 @@ def register_signals( | |||||||
|         """Sync group membership""" |         """Sync group membership""" | ||||||
|         if action not in ["post_add", "post_remove"]: |         if action not in ["post_add", "post_remove"]: | ||||||
|             return |             return | ||||||
|         task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse) |         if not provider_type.objects.filter( | ||||||
|  |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|  |         ).exists(): | ||||||
|  |             return | ||||||
|  |         # reverse: instance is a Group, pk_set is a list of user pks | ||||||
|  |         # non-reverse: instance is a User, pk_set is a list of group pks | ||||||
|  |         if reverse: | ||||||
|  |             task_sync_m2m.delay(str(instance.pk), action, list(pk_set)) | ||||||
|  |         else: | ||||||
|  |             for group_pk in pk_set: | ||||||
|  |                 task_sync_m2m.delay(group_pk, action, [instance.pk]) | ||||||
|  |  | ||||||
|     m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) |     m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) | ||||||
|  | |||||||
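`model_m2m_changed` fans membership changes out per group: for a reverse change the group itself is dispatched with all affected user pks, while a non-reverse change dispatches each affected group with the single user. A stand-alone illustration of that fan-out (no Django or Celery required), mirroring the branches above:

    def fan_out(instance_pk, action, pk_set, reverse):
        """Return (group_pk, action, member_pks) dispatch tuples, mirroring the handler above."""
        if reverse:  # instance is a Group, pk_set holds user pks
            return [(str(instance_pk), action, sorted(pk_set))]
        # non-reverse: instance is a User, pk_set holds group pks
        return [(group_pk, action, [instance_pk]) for group_pk in sorted(pk_set)]

    print(fan_out("group-1", "post_add", {10, 11}, reverse=True))
    # [('group-1', 'post_add', [10, 11])]
    print(fan_out(10, "post_remove", {"g1", "g2"}, reverse=False))
    # [('g1', 'post_remove', [10]), ('g2', 'post_remove', [10])]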
| @ -1,17 +1,23 @@ | |||||||
|  | from collections.abc import Callable | ||||||
|  | from dataclasses import asdict | ||||||
|  |  | ||||||
|  | from celery import group | ||||||
|  | from celery.exceptions import Retry | ||||||
|  | from celery.result import allow_join_result | ||||||
| from django.core.paginator import Paginator | from django.core.paginator import Paginator | ||||||
| from django.db.models import Model, QuerySet | from django.db.models import Model, QuerySet | ||||||
| from django.db.models.query import Q | from django.db.models.query import Q | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask | from django.utils.translation import gettext_lazy as _ | ||||||
| from dramatiq.actor import Actor |  | ||||||
| from dramatiq.composition import group |  | ||||||
| from dramatiq.errors import Retry |  | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.expression.exceptions import SkipObjectException | from authentik.core.expression.exceptions import SkipObjectException | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
|  | from authentik.events.logs import LogEvent | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask | ||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS | from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT | ||||||
| from authentik.lib.sync.outgoing.base import Direction | from authentik.lib.sync.outgoing.base import Direction | ||||||
| from authentik.lib.sync.outgoing.exceptions import ( | from authentik.lib.sync.outgoing.exceptions import ( | ||||||
|     BadRequestSyncException, |     BadRequestSyncException, | ||||||
| @ -21,12 +27,11 @@ from authentik.lib.sync.outgoing.exceptions import ( | |||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path, path_to_class | from authentik.lib.utils.reflection import class_to_path, path_to_class | ||||||
| from authentik.tasks.models import Task |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SyncTasks: | class SyncTasks: | ||||||
|     """Container for all sync 'tasks' (this class doesn't actually contain |     """Container for all sync 'tasks' (this class doesn't actually contain celery | ||||||
|     tasks due to dramatiq's magic; it does, however, expose a number of functions to be called from tasks)""" |     tasks due to celery's magic; it does, however, expose a number of functions to be called from tasks)""" | ||||||
|  |  | ||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
| @ -34,104 +39,107 @@ class SyncTasks: | |||||||
|         super().__init__() |         super().__init__() | ||||||
|         self._provider_model = provider_model |         self._provider_model = provider_model | ||||||
|  |  | ||||||
|     def sync_paginator( |     def sync_all(self, single_sync: Callable[[int], None]): | ||||||
|         self, |         for provider in self._provider_model.objects.filter( | ||||||
|         current_task: Task, |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|         provider: OutgoingSyncProvider, |  | ||||||
|         sync_objects: Actor[[str, int, int, bool], None], |  | ||||||
|         paginator: Paginator, |  | ||||||
|         object_type: type[User | Group], |  | ||||||
|         **options, |  | ||||||
|         ): |         ): | ||||||
|         tasks = [] |             self.trigger_single_task(provider, single_sync) | ||||||
|         for page in paginator.page_range: |  | ||||||
|             page_sync = sync_objects.message_with_options( |  | ||||||
|                 args=(class_to_path(object_type), page, provider.pk), |  | ||||||
|                 time_limit=PAGE_TIMEOUT_MS, |  | ||||||
|                 # Assign tasks to the same schedule as the current one |  | ||||||
|                 rel_obj=current_task.rel_obj, |  | ||||||
|                 **options, |  | ||||||
|             ) |  | ||||||
|             tasks.append(page_sync) |  | ||||||
|         return tasks |  | ||||||
|  |  | ||||||
|     def sync( |     def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]): | ||||||
|  |         """Wrapper single sync task that correctly sets time limits based | ||||||
|  |         on the amount of objects that will be synced""" | ||||||
|  |         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) | ||||||
|  |         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) | ||||||
|  |         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT | ||||||
|  |         time_limit = soft_time_limit * 1.5 | ||||||
|  |         return sync_task.apply_async( | ||||||
|  |             (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def sync_single( | ||||||
|         self, |         self, | ||||||
|  |         task: SystemTask, | ||||||
|         provider_pk: int, |         provider_pk: int, | ||||||
|         sync_objects: Actor[[str, int, int, bool], None], |         sync_objects: Callable[[int, int], list[str]], | ||||||
|     ): |     ): | ||||||
|         task: Task = CurrentTask.get_task() |  | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|             provider_pk=provider_pk, |             provider_pk=provider_pk, | ||||||
|         ) |         ) | ||||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( |         provider = self._provider_model.objects.filter( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), |             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||||
|             pk=provider_pk, |             pk=provider_pk, | ||||||
|         ).first() |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|             task.warning("No provider found. Is it assigned to an application?") |  | ||||||
|             return |             return | ||||||
|         task.set_uid(slugify(provider.name)) |         task.set_uid(slugify(provider.name)) | ||||||
|         task.info("Starting full provider sync") |         messages = [] | ||||||
|  |         messages.append(_("Starting full provider sync")) | ||||||
|         self.logger.debug("Starting provider sync") |         self.logger.debug("Starting provider sync") | ||||||
|         with provider.sync_lock as lock_acquired: |         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) | ||||||
|  |         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) | ||||||
|  |         with allow_join_result(), provider.sync_lock as lock_acquired: | ||||||
|             if not lock_acquired: |             if not lock_acquired: | ||||||
|                 task.info("Synchronization is already running. Skipping.") |  | ||||||
|                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) |                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) | ||||||
|                 return |                 return | ||||||
|             try: |             try: | ||||||
|                 users_tasks = group( |                 messages.append(_("Syncing users")) | ||||||
|                     self.sync_paginator( |                 user_results = ( | ||||||
|                         current_task=task, |                     group( | ||||||
|                         provider=provider, |                         [ | ||||||
|                         sync_objects=sync_objects, |                             sync_objects.signature( | ||||||
|                         paginator=provider.get_paginator(User), |                                 args=(class_to_path(User), page, provider_pk), | ||||||
|                         object_type=User, |                                 time_limit=PAGE_TIMEOUT, | ||||||
|  |                                 soft_time_limit=PAGE_TIMEOUT, | ||||||
|                             ) |                             ) | ||||||
|  |                             for page in users_paginator.page_range | ||||||
|  |                         ] | ||||||
|                     ) |                     ) | ||||||
|                 group_tasks = group( |                     .apply_async() | ||||||
|                     self.sync_paginator( |                     .get() | ||||||
|                         current_task=task, |  | ||||||
|                         provider=provider, |  | ||||||
|                         sync_objects=sync_objects, |  | ||||||
|                         paginator=provider.get_paginator(Group), |  | ||||||
|                         object_type=Group, |  | ||||||
|                 ) |                 ) | ||||||
|  |                 for result in user_results: | ||||||
|  |                     for msg in result: | ||||||
|  |                         messages.append(LogEvent(**msg)) | ||||||
|  |                 messages.append(_("Syncing groups")) | ||||||
|  |                 group_results = ( | ||||||
|  |                     group( | ||||||
|  |                         [ | ||||||
|  |                             sync_objects.signature( | ||||||
|  |                                 args=(class_to_path(Group), page, provider_pk), | ||||||
|  |                                 time_limit=PAGE_TIMEOUT, | ||||||
|  |                                 soft_time_limit=PAGE_TIMEOUT, | ||||||
|                             ) |                             ) | ||||||
|                 users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) |                             for page in groups_paginator.page_range | ||||||
|                 group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) |                         ] | ||||||
|  |                     ) | ||||||
|  |                     .apply_async() | ||||||
|  |                     .get() | ||||||
|  |                 ) | ||||||
|  |                 for result in group_results: | ||||||
|  |                     for msg in result: | ||||||
|  |                         messages.append(LogEvent(**msg)) | ||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 self.logger.warning("transient sync exception", exc=exc) |                 self.logger.warning("transient sync exception", exc=exc) | ||||||
|                 task.warning("Sync encountered a transient exception. Retrying", exc=exc) |                 raise task.retry(exc=exc) from exc | ||||||
|                 raise Retry() from exc |  | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|                 task.error(exc) |                 task.set_error(exc) | ||||||
|                 return |                 return | ||||||
|  |         task.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  |  | ||||||
|     def sync_objects( |     def sync_objects( | ||||||
|         self, |         self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter | ||||||
|         object_type: str, |  | ||||||
|         page: int, |  | ||||||
|         provider_pk: int, |  | ||||||
|         override_dry_run=False, |  | ||||||
|         **filter, |  | ||||||
|     ): |     ): | ||||||
|         task: Task = CurrentTask.get_task() |  | ||||||
|         _object_type: type[Model] = path_to_class(object_type) |         _object_type: type[Model] = path_to_class(object_type) | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|             provider_pk=provider_pk, |             provider_pk=provider_pk, | ||||||
|             object_type=object_type, |             object_type=object_type, | ||||||
|         ) |         ) | ||||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( |         messages = [] | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), |         provider = self._provider_model.objects.filter(pk=provider_pk).first() | ||||||
|             pk=provider_pk, |  | ||||||
|         ).first() |  | ||||||
|         if not provider: |         if not provider: | ||||||
|             task.warning("No provider found. Is it assigned to an application?") |             return messages | ||||||
|             return |  | ||||||
|         task.set_uid(slugify(provider.name)) |  | ||||||
|         # Override dry run mode if requested, however don't save the provider |         # Override dry run mode if requested, however don't save the provider | ||||||
|         # so that scheduled sync tasks still run in dry_run mode |         # so that scheduled sync tasks still run in dry_run mode | ||||||
|         if override_dry_run: |         if override_dry_run: | ||||||
| @ -139,13 +147,25 @@ class SyncTasks: | |||||||
|         try: |         try: | ||||||
|             client = provider.client_for_model(_object_type) |             client = provider.client_for_model(_object_type) | ||||||
|         except TransientSyncException: |         except TransientSyncException: | ||||||
|             return |             return messages | ||||||
|         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) |         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) | ||||||
|         if client.can_discover: |         if client.can_discover: | ||||||
|             self.logger.debug("starting discover") |             self.logger.debug("starting discover") | ||||||
|             client.discover() |             client.discover() | ||||||
|         self.logger.debug("starting sync for page", page=page) |         self.logger.debug("starting sync for page", page=page) | ||||||
|         task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}") |         messages.append( | ||||||
|  |             asdict( | ||||||
|  |                 LogEvent( | ||||||
|  |                     _( | ||||||
|  |                         "Syncing page {page} of {object_type}".format( | ||||||
|  |                             page=page, object_type=_object_type._meta.verbose_name_plural | ||||||
|  |                         ) | ||||||
|  |                     ), | ||||||
|  |                     log_level="info", | ||||||
|  |                     logger=f"{provider._meta.verbose_name}@{object_type}", | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|         for obj in paginator.page(page).object_list: |         for obj in paginator.page(page).object_list: | ||||||
|             obj: Model |             obj: Model | ||||||
|             try: |             try: | ||||||
| @ -154,58 +174,89 @@ class SyncTasks: | |||||||
|                 self.logger.debug("skipping object due to SkipObject", obj=obj) |                 self.logger.debug("skipping object due to SkipObject", obj=obj) | ||||||
|                 continue |                 continue | ||||||
|             except DryRunRejected as exc: |             except DryRunRejected as exc: | ||||||
|                 task.info( |                 messages.append( | ||||||
|                     "Dropping mutating request due to dry run", |                     asdict( | ||||||
|                     obj=sanitize_item(obj), |                         LogEvent( | ||||||
|                     method=exc.method, |                             _("Dropping mutating request due to dry run"), | ||||||
|                     url=exc.url, |                             log_level="info", | ||||||
|                     body=exc.body, |                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||||
|  |                             attributes={ | ||||||
|  |                                 "obj": sanitize_item(obj), | ||||||
|  |                                 "method": exc.method, | ||||||
|  |                                 "url": exc.url, | ||||||
|  |                                 "body": exc.body, | ||||||
|  |                             }, | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                 ) |                 ) | ||||||
|             except BadRequestSyncException as exc: |             except BadRequestSyncException as exc: | ||||||
|                 self.logger.warning("failed to sync object", exc=exc, obj=obj) |                 self.logger.warning("failed to sync object", exc=exc, obj=obj) | ||||||
|                 task.warning( |                 messages.append( | ||||||
|                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", |                     asdict( | ||||||
|                     arguments=exc.args[1:], |                         LogEvent( | ||||||
|                     obj=sanitize_item(obj), |                             _( | ||||||
|  |                                 ( | ||||||
|  |                                     "Failed to sync {object_type} {object_name} " | ||||||
|  |                                     "due to error: {error}" | ||||||
|  |                                 ).format_map( | ||||||
|  |                                     { | ||||||
|  |                                         "object_type": obj._meta.verbose_name, | ||||||
|  |                                         "object_name": str(obj), | ||||||
|  |                                         "error": str(exc), | ||||||
|  |                                     } | ||||||
|  |                                 ) | ||||||
|  |                             ), | ||||||
|  |                             log_level="warning", | ||||||
|  |                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||||
|  |                             attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)}, | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                 ) |                 ) | ||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 self.logger.warning("failed to sync object", exc=exc, user=obj) |                 self.logger.warning("failed to sync object", exc=exc, user=obj) | ||||||
|                 task.warning( |                 messages.append( | ||||||
|                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " |                     asdict( | ||||||
|                     "transient error: {str(exc)}", |                         LogEvent( | ||||||
|                     obj=sanitize_item(obj), |                             _( | ||||||
|  |                                 ( | ||||||
|  |                                     "Failed to sync {object_type} {object_name} " | ||||||
|  |                                     "due to transient error: {error}" | ||||||
|  |                                 ).format_map( | ||||||
|  |                                     { | ||||||
|  |                                         "object_type": obj._meta.verbose_name, | ||||||
|  |                                         "object_name": str(obj), | ||||||
|  |                                         "error": str(exc), | ||||||
|  |                                     } | ||||||
|  |                                 ) | ||||||
|  |                             ), | ||||||
|  |                             log_level="warning", | ||||||
|  |                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||||
|  |                             attributes={"obj": sanitize_item(obj)}, | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                 ) |                 ) | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|                 self.logger.warning("Stopping sync", exc=exc) |                 self.logger.warning("Stopping sync", exc=exc) | ||||||
|                 task.warning( |                 messages.append( | ||||||
|                     f"Stopping sync due to error: {exc.detail()}", |                     asdict( | ||||||
|                     obj=sanitize_item(obj), |                         LogEvent( | ||||||
|  |                             _( | ||||||
|  |                                 "Stopping sync due to error: {error}".format_map( | ||||||
|  |                                     { | ||||||
|  |                                         "error": exc.detail(), | ||||||
|  |                                     } | ||||||
|  |                                 ) | ||||||
|  |                             ), | ||||||
|  |                             log_level="warning", | ||||||
|  |                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||||
|  |                             attributes={"obj": sanitize_item(obj)}, | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                 ) |                 ) | ||||||
|                 break |                 break | ||||||
|  |         return messages | ||||||
|  |  | ||||||
|     def sync_signal_direct_dispatch( |     def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): | ||||||
|         self, |  | ||||||
|         task_sync_signal_direct: Actor[[str, str | int, int, str], None], |  | ||||||
|         model: str, |  | ||||||
|         pk: str | int, |  | ||||||
|         raw_op: str, |  | ||||||
|     ): |  | ||||||
|         for provider in self._provider_model.objects.filter( |  | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |  | ||||||
|         ): |  | ||||||
|             task_sync_signal_direct.send_with_options( |  | ||||||
|                 args=(model, pk, provider.pk, raw_op), |  | ||||||
|                 rel_obj=provider, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     def sync_signal_direct( |  | ||||||
|         self, |  | ||||||
|         model: str, |  | ||||||
|         pk: str | int, |  | ||||||
|         provider_pk: int, |  | ||||||
|         raw_op: str, |  | ||||||
|     ): |  | ||||||
|         task: Task = CurrentTask.get_task() |  | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|         ) |         ) | ||||||
| @ -213,25 +264,20 @@ class SyncTasks: | |||||||
|         instance = model_class.objects.filter(pk=pk).first() |         instance = model_class.objects.filter(pk=pk).first() | ||||||
|         if not instance: |         if not instance: | ||||||
|             return |             return | ||||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( |  | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), |  | ||||||
|             pk=provider_pk, |  | ||||||
|         ).first() |  | ||||||
|         if not provider: |  | ||||||
|             task.warning("No provider found. Is it assigned to an application?") |  | ||||||
|             return |  | ||||||
|         task.set_uid(slugify(provider.name)) |  | ||||||
|         operation = Direction(raw_op) |         operation = Direction(raw_op) | ||||||
|  |         for provider in self._provider_model.objects.filter( | ||||||
|  |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|  |         ): | ||||||
|             client = provider.client_for_model(instance.__class__) |             client = provider.client_for_model(instance.__class__) | ||||||
|             # Check if the object is allowed within the provider's restrictions |             # Check if the object is allowed within the provider's restrictions | ||||||
|             queryset = provider.get_object_qs(instance.__class__) |             queryset = provider.get_object_qs(instance.__class__) | ||||||
|             if not queryset: |             if not queryset: | ||||||
|             return |                 continue | ||||||
|  |  | ||||||
|             # The queryset we get from the provider must include the instance we were given, |             # The queryset we get from the provider must include the instance we were given, | ||||||
|             # otherwise ignore this provider |             # otherwise ignore this provider | ||||||
|             if not queryset.filter(pk=instance.pk).exists(): |             if not queryset.filter(pk=instance.pk).exists(): | ||||||
|             return |                 continue | ||||||
|  |  | ||||||
|             try: |             try: | ||||||
|                 if operation == Direction.add: |                 if operation == Direction.add: | ||||||
| @ -241,66 +287,28 @@ class SyncTasks: | |||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 raise Retry() from exc |                 raise Retry() from exc | ||||||
|             except SkipObjectException: |             except SkipObjectException: | ||||||
|             return |                 continue | ||||||
|             except DryRunRejected as exc: |             except DryRunRejected as exc: | ||||||
|                 self.logger.info("Rejected dry-run event", exc=exc) |                 self.logger.info("Rejected dry-run event", exc=exc) | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|                 self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) |                 self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||||
|  |  | ||||||
|     def sync_signal_m2m_dispatch( |     def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): | ||||||
|         self, |  | ||||||
|         task_sync_signal_m2m: Actor[[str, int, str, list[int]], None], |  | ||||||
|         instance_pk: str, |  | ||||||
|         action: str, |  | ||||||
|         pk_set: list[int], |  | ||||||
|         reverse: bool, |  | ||||||
|     ): |  | ||||||
|         for provider in self._provider_model.objects.filter( |  | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |  | ||||||
|         ): |  | ||||||
|             # reverse: instance is a Group, pk_set is a list of user pks |  | ||||||
|             # non-reverse: instance is a User, pk_set is a list of groups |  | ||||||
|             if reverse: |  | ||||||
|                 task_sync_signal_m2m.send_with_options( |  | ||||||
|                     args=(instance_pk, provider.pk, action, list(pk_set)), |  | ||||||
|                     rel_obj=provider, |  | ||||||
|                 ) |  | ||||||
|             else: |  | ||||||
|                 for pk in pk_set: |  | ||||||
|                     task_sync_signal_m2m.send_with_options( |  | ||||||
|                         args=(pk, provider.pk, action, [instance_pk]), |  | ||||||
|                         rel_obj=provider, |  | ||||||
|                     ) |  | ||||||
|  |  | ||||||
|     def sync_signal_m2m( |  | ||||||
|         self, |  | ||||||
|         group_pk: str, |  | ||||||
|         provider_pk: int, |  | ||||||
|         action: str, |  | ||||||
|         pk_set: list[int], |  | ||||||
|     ): |  | ||||||
|         task: Task = CurrentTask.get_task() |  | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|         ) |         ) | ||||||
|         group = Group.objects.filter(pk=group_pk).first() |         group = Group.objects.filter(pk=group_pk).first() | ||||||
|         if not group: |         if not group: | ||||||
|             return |             return | ||||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( |         for provider in self._provider_model.objects.filter( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|             pk=provider_pk, |         ): | ||||||
|         ).first() |  | ||||||
|         if not provider: |  | ||||||
|             task.warning("No provider found. Is it assigned to an application?") |  | ||||||
|             return |  | ||||||
|         task.set_uid(slugify(provider.name)) |  | ||||||
|  |  | ||||||
|             # Check if the object is allowed within the provider's restrictions |             # Check if the object is allowed within the provider's restrictions | ||||||
|             queryset: QuerySet = provider.get_object_qs(Group) |             queryset: QuerySet = provider.get_object_qs(Group) | ||||||
|             # The queryset we get from the provider must include the instance we were given |             # The queryset we get from the provider must include the instance we were given | ||||||
|             # otherwise ignore this provider |             # otherwise ignore this provider | ||||||
|             if not queryset.filter(pk=group_pk).exists(): |             if not queryset.filter(pk=group_pk).exists(): | ||||||
|             return |                 continue | ||||||
|  |  | ||||||
|             client = provider.client_for_model(Group) |             client = provider.client_for_model(Group) | ||||||
|             try: |             try: | ||||||
| @ -313,7 +321,7 @@ class SyncTasks: | |||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 raise Retry() from exc |                 raise Retry() from exc | ||||||
|             except SkipObjectException: |             except SkipObjectException: | ||||||
|             return |                 continue | ||||||
|             except DryRunRejected as exc: |             except DryRunRejected as exc: | ||||||
|                 self.logger.info("Rejected dry-run event", exc=exc) |                 self.logger.info("Rejected dry-run event", exc=exc) | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|  | |||||||
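A note on the hunks above: the left-hand (dramatiq) variants handle a single provider per task and can simply return, while the right-hand (Celery) variants loop over every matching provider inside one task, so a provider that filters out or skips the object must continue to the next one instead of ending the task. A minimal sketch of that loop pattern, with placeholder names rather than the real SyncTasks API:

# Minimal sketch of the loop-over-providers pattern, not the actual SyncTasks
# implementation; provider/instance stand in for authentik's real objects and
# SkipObject stands in for SkipObjectException.
class SkipObject(Exception):
    """Placeholder for a per-provider "skip this object" signal."""


def sync_signal_direct_sketch(providers, instance, operation):
    for provider in providers:
        queryset = provider.get_object_qs(type(instance))
        # A provider that does not manage this object is skipped, not fatal
        if not queryset or not queryset.filter(pk=instance.pk).exists():
            continue
        client = provider.client_for_model(type(instance))
        try:
            if operation == "add":
                client.write(instance)
            else:
                client.delete(instance)
        except SkipObject:
            continue  # skip only this provider, keep syncing the rest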
| @ -5,8 +5,6 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -62,27 +60,3 @@ class AuthentikOutpostConfig(ManagedAppConfig): | |||||||
|                 outpost.save() |                 outpost.save() | ||||||
|         else: |         else: | ||||||
|             Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() |             Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.outposts.tasks import outpost_token_ensurer |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=outpost_token_ensurer, |  | ||||||
|                 crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *", |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def global_schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.outposts.tasks import outpost_connection_discovery |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=outpost_connection_discovery, |  | ||||||
|                 crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *", |  | ||||||
|                 send_on_startup=True, |  | ||||||
|                 paused=not CONFIG.get_bool("outposts.discover"), |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  | |||||||
| @ -101,13 +101,7 @@ class KubernetesController(BaseController): | |||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: |                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||||
|                     all_logs.append( |                     all_logs += [f"{reconcile_key.title()}: Disabled"] | ||||||
|                         LogEvent( |  | ||||||
|                             log_level="info", |  | ||||||
|                             event=f"{reconcile_key.title()}: Disabled", |  | ||||||
|                             logger=str(type(self)), |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                     continue |                     continue | ||||||
|                 with capture_logs() as logs: |                 with capture_logs() as logs: | ||||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) |                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||||
| @ -140,13 +134,7 @@ class KubernetesController(BaseController): | |||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: |                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||||
|                     all_logs.append( |                     all_logs += [f"{reconcile_key.title()}: Disabled"] | ||||||
|                         LogEvent( |  | ||||||
|                             log_level="info", |  | ||||||
|                             event=f"{reconcile_key.title()}: Disabled", |  | ||||||
|                             logger=str(type(self)), |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                     continue |                     continue | ||||||
|                 with capture_logs() as logs: |                 with capture_logs() as logs: | ||||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) |                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||||
|  | |||||||
| @ -36,10 +36,7 @@ from authentik.lib.config import CONFIG | |||||||
| from authentik.lib.models import InheritanceForeignKey, SerializerModel | from authentik.lib.models import InheritanceForeignKey, SerializerModel | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
| from authentik.outposts.controllers.k8s.utils import get_namespace | from authentik.outposts.controllers.k8s.utils import get_namespace | ||||||
| from authentik.tasks.schedules.lib import ScheduleSpec |  | ||||||
| from authentik.tasks.schedules.models import ScheduledModel |  | ||||||
|  |  | ||||||
| OUR_VERSION = parse(__version__) | OUR_VERSION = parse(__version__) | ||||||
| OUTPOST_HELLO_INTERVAL = 10 | OUTPOST_HELLO_INTERVAL = 10 | ||||||
| @ -118,7 +115,7 @@ class OutpostServiceConnectionState: | |||||||
|     healthy: bool |     healthy: bool | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutpostServiceConnection(ScheduledModel, models.Model): | class OutpostServiceConnection(models.Model): | ||||||
|     """Connection details for an Outpost Controller, like Docker or Kubernetes""" |     """Connection details for an Outpost Controller, like Docker or Kubernetes""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) |     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||||
| @ -148,11 +145,11 @@ class OutpostServiceConnection(ScheduledModel, models.Model): | |||||||
|     @property |     @property | ||||||
|     def state(self) -> OutpostServiceConnectionState: |     def state(self) -> OutpostServiceConnectionState: | ||||||
|         """Get state of service connection""" |         """Get state of service connection""" | ||||||
|         from authentik.outposts.tasks import outpost_service_connection_monitor |         from authentik.outposts.tasks import outpost_service_connection_state | ||||||
|  |  | ||||||
|         state = cache.get(self.state_key, None) |         state = cache.get(self.state_key, None) | ||||||
|         if not state: |         if not state: | ||||||
|             outpost_service_connection_monitor.send_with_options(args=(self.pk,), rel_obj=self) |             outpost_service_connection_state.delay(self.pk) | ||||||
|             return OutpostServiceConnectionState("", False) |             return OutpostServiceConnectionState("", False) | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
| @ -163,20 +160,6 @@ class OutpostServiceConnection(ScheduledModel, models.Model): | |||||||
|         # since the response doesn't use the correct inheritance |         # since the response doesn't use the correct inheritance | ||||||
|         return "" |         return "" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.outposts.tasks import outpost_service_connection_monitor |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=outpost_service_connection_monitor, |  | ||||||
|                 uid=self.pk, |  | ||||||
|                 args=(self.pk,), |  | ||||||
|                 crontab="3-59/15 * * * *", |  | ||||||
|                 send_on_save=True, |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DockerServiceConnection(SerializerModel, OutpostServiceConnection): | class DockerServiceConnection(SerializerModel, OutpostServiceConnection): | ||||||
|     """Service Connection to a Docker endpoint""" |     """Service Connection to a Docker endpoint""" | ||||||
| @ -261,7 +244,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): | |||||||
|         return "ak-service-connection-kubernetes-form" |         return "ak-service-connection-kubernetes-form" | ||||||
|  |  | ||||||
|  |  | ||||||
| class Outpost(ScheduledModel, SerializerModel, ManagedModel): | class Outpost(SerializerModel, ManagedModel): | ||||||
|     """Outpost instance which manages a service user and token""" |     """Outpost instance which manages a service user and token""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) |     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||||
| @ -315,21 +298,6 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel): | |||||||
|         """Username for service user""" |         """Username for service user""" | ||||||
|         return f"ak-outpost-{self.uuid.hex}" |         return f"ak-outpost-{self.uuid.hex}" | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def schedule_specs(self) -> list[ScheduleSpec]: |  | ||||||
|         from authentik.outposts.tasks import outpost_controller |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ScheduleSpec( |  | ||||||
|                 actor=outpost_controller, |  | ||||||
|                 uid=self.pk, |  | ||||||
|                 args=(self.pk,), |  | ||||||
|                 kwargs={"action": "up", "from_cache": False}, |  | ||||||
|                 crontab=f"{fqdn_rand('outpost_controller')} */4 * * *", |  | ||||||
|                 send_on_save=True, |  | ||||||
|             ), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     def build_user_permissions(self, user: User): |     def build_user_permissions(self, user: User): | ||||||
|         """Create per-object and global permissions for outpost service-account""" |         """Create per-object and global permissions for outpost service-account""" | ||||||
|         # To ensure the user only has the correct permissions, we delete all of them and re-add |         # To ensure the user only has the correct permissions, we delete all of them and re-add | ||||||
|  | |||||||
							
								
								
									
authentik/outposts/settings.py (new file, 28 lines)
							| @ -0,0 +1,28 @@ | |||||||
|  | """Outposts Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "outposts_controller": { | ||||||
|  |         "task": "authentik.outposts.tasks.outpost_controller_all", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  |     "outposts_service_connection_check": { | ||||||
|  |         "task": "authentik.outposts.tasks.outpost_service_connection_monitor", | ||||||
|  |         "schedule": crontab(minute="3-59/15"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  |     "outpost_token_ensurer": { | ||||||
|  |         "task": "authentik.outposts.tasks.outpost_token_ensurer", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  |     "outpost_connection_discovery": { | ||||||
|  |         "task": "authentik.outposts.tasks.outpost_connection_discovery", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
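The new outposts settings module above registers four periodic jobs with Celery beat on the authentik_scheduled queue. The minute is taken from fqdn_rand (imported from authentik.lib.utils.time), which by its use here looks like a stable per-host offset so that several authentik instances do not all fire at the same moment; a hypothetical stand-in, purely for illustration:

# Hypothetical stand-in for authentik.lib.utils.time.fqdn_rand; the real helper
# may differ. It only illustrates "derive a deterministic minute from the FQDN".
from hashlib import sha256
from socket import getfqdn


def fqdn_rand_sketch(seed: str, stop: int = 60) -> int:
    """Map this host's FQDN plus a seed string to a stable value in [0, stop)."""
    digest = sha256(f"{getfqdn()}{seed}".encode()).hexdigest()
    return int(digest, 16) % stop

Under that assumption, crontab(minute=fqdn_rand_sketch("outposts_controller"), hour="*/4") yields the same minute on every beat tick for a given host, but different minutes across hosts.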
| @ -1,6 +1,7 @@ | |||||||
| """authentik outpost signals""" | """authentik outpost signals""" | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
|  | from django.db.models import Model | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -8,19 +9,27 @@ from structlog.stdlib import get_logger | |||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import AuthenticatedSession, Provider | from authentik.core.models import AuthenticatedSession, Provider | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection | from authentik.lib.utils.reflection import class_to_path | ||||||
|  | from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||||
| from authentik.outposts.tasks import ( | from authentik.outposts.tasks import ( | ||||||
|     CACHE_KEY_OUTPOST_DOWN, |     CACHE_KEY_OUTPOST_DOWN, | ||||||
|     outpost_controller, |     outpost_controller, | ||||||
|     outpost_send_update, |     outpost_post_save, | ||||||
|     outpost_session_end, |     outpost_session_end, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  | UPDATE_TRIGGERING_MODELS = ( | ||||||
|  |     Outpost, | ||||||
|  |     OutpostServiceConnection, | ||||||
|  |     Provider, | ||||||
|  |     CertificateKeyPair, | ||||||
|  |     Brand, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=Outpost) | @receiver(pre_save, sender=Outpost) | ||||||
| def outpost_pre_save(sender, instance: Outpost, **_): | def pre_save_outpost(sender, instance: Outpost, **_): | ||||||
|     """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, |     """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, | ||||||
|     we call down and then wait for the up after save""" |     we call down and then wait for the up after save""" | ||||||
|     old_instances = Outpost.objects.filter(pk=instance.pk) |     old_instances = Outpost.objects.filter(pk=instance.pk) | ||||||
| @ -35,89 +44,43 @@ def outpost_pre_save(sender, instance: Outpost, **_): | |||||||
|     if bool(dirty): |     if bool(dirty): | ||||||
|         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) |         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) | ||||||
|         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) |         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) | ||||||
|         outpost_controller.send_with_options( |         outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||||
|             args=(instance.pk.hex,), |  | ||||||
|             kwargs={"action": "down", "from_cache": True}, |  | ||||||
|             rel_obj=instance, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(m2m_changed, sender=Outpost.providers.through) | @receiver(m2m_changed, sender=Outpost.providers.through) | ||||||
| def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_): | def m2m_changed_update(sender, instance: Model, action: str, **_): | ||||||
|     """Update outpost on m2m change, when providers are added or removed""" |     """Update outpost on m2m change, when providers are added or removed""" | ||||||
|     if action not in ["post_add", "post_remove", "post_clear"]: |     if action in ["post_add", "post_remove", "post_clear"]: | ||||||
|  |         outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(post_save) | ||||||
|  | def post_save_update(sender, instance: Model, created: bool, **_): | ||||||
|  |     """If an Outpost is saved, Ensure that token is created/updated | ||||||
|  |  | ||||||
|  |     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, | ||||||
|  |     we send a message down the relevant OutpostModels WS connection to trigger an update""" | ||||||
|  |     if instance.__module__ == "django.db.migrations.recorder": | ||||||
|         return |         return | ||||||
|     if isinstance(instance, Outpost): |     if instance.__module__ == "__fake__": | ||||||
|         outpost_controller.send_with_options( |         return | ||||||
|             args=(instance.pk,), |     if not isinstance(instance, UPDATE_TRIGGERING_MODELS): | ||||||
|             rel_obj=instance.service_connection, |         return | ||||||
|         ) |     if isinstance(instance, Outpost) and created: | ||||||
|         outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) |  | ||||||
|     elif isinstance(instance, OutpostModel): |  | ||||||
|         for outpost in instance.outpost_set.all(): |  | ||||||
|             outpost_controller.send_with_options( |  | ||||||
|                 args=(outpost.pk,), |  | ||||||
|                 rel_obj=outpost.service_connection, |  | ||||||
|             ) |  | ||||||
|             outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save, sender=Outpost) |  | ||||||
| def outpost_post_save(sender, instance: Outpost, created: bool, **_): |  | ||||||
|     if created: |  | ||||||
|         LOGGER.info("New outpost saved, ensuring initial token and user are created") |         LOGGER.info("New outpost saved, ensuring initial token and user are created") | ||||||
|         _ = instance.token |         _ = instance.token | ||||||
|     outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection) |     outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) | ||||||
|     outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_): |  | ||||||
|     for outpost in instance.outpost_set.all(): |  | ||||||
|         outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False) |  | ||||||
| for subclass in OutpostModel.__subclasses__(): |  | ||||||
|     post_save.connect(outpost_related_post_save, sender=subclass, weak=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_): |  | ||||||
|     for field in instance._meta.get_fields(): |  | ||||||
|         # Each field is checked if it has a `related_model` attribute (when ForeignKeys or M2Ms) |  | ||||||
|         # are used, and if it has a value |  | ||||||
|         if not hasattr(field, "related_model"): |  | ||||||
|             continue |  | ||||||
|         if not field.related_model: |  | ||||||
|             continue |  | ||||||
|         if not issubclass(field.related_model, OutpostModel): |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         field_name = f"{field.name}_set" |  | ||||||
|         if not hasattr(instance, field_name): |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         LOGGER.debug("triggering outpost update from field", field=field.name) |  | ||||||
|         # Because the Outpost Model has an M2M to Provider, |  | ||||||
|         # we have to iterate over the entire QS |  | ||||||
|         for reverse in getattr(instance, field_name).all(): |  | ||||||
|             if isinstance(reverse, OutpostModel): |  | ||||||
|                 for outpost in reverse.outpost_set.all(): |  | ||||||
|                     outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False) |  | ||||||
| post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=Outpost) | @receiver(pre_delete, sender=Outpost) | ||||||
| def outpost_pre_delete_cleanup(sender, instance: Outpost, **_): | def pre_delete_cleanup(sender, instance: Outpost, **_): | ||||||
|     """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" |     """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" | ||||||
|     instance.user.delete() |     instance.user.delete() | ||||||
|     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) |     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) | ||||||
|     outpost_controller.send(instance.pk.hex, action="down", from_cache=True) |     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|     """Catch logout by expiring sessions being deleted""" |     """Catch logout by expiring sessions being deleted""" | ||||||
|     outpost_session_end.send(instance.session.session_key) |     outpost_session_end.delay(instance.session.session_key) | ||||||
|  | |||||||
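In the Celery-side signals above, the per-model receivers are collapsed into one generic post_save receiver gated by UPDATE_TRIGGERING_MODELS (plus the migration/__fake__ guards), and the affected model is handed to the outpost_post_save task as a dotted import path via class_to_path, keeping the task arguments plain strings and primary keys. A rough, hypothetical equivalent of the class_to_path / path_to_class pair (the real helpers live in authentik.lib.utils.reflection and may differ):

# Illustrative only; not authentik's actual reflection helpers.
from importlib import import_module


def class_to_path_sketch(cls: type) -> str:
    """Serialize a class into a dotted import path, e.g. "authentik.outposts.models.Outpost"."""
    return f"{cls.__module__}.{cls.__name__}"


def path_to_class_sketch(path: str) -> type:
    """Resolve a dotted import path back to the class object."""
    module_name, _, class_name = path.rpartition(".")
    return getattr(import_module(module_name), class_name)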
| @ -10,17 +10,19 @@ from urllib.parse import urlparse | |||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
|  | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  | from django.db.models.base import Model | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from django_dramatiq_postgres.middleware import CurrentTask |  | ||||||
| from docker.constants import DEFAULT_UNIX_SOCKET | from docker.constants import DEFAULT_UNIX_SOCKET | ||||||
| from dramatiq.actor import actor |  | ||||||
| from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | ||||||
| from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from yaml import safe_load | from yaml import safe_load | ||||||
|  |  | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  | from authentik.lib.utils.reflection import path_to_class | ||||||
| from authentik.outposts.consumer import OUTPOST_GROUP | from authentik.outposts.consumer import OUTPOST_GROUP | ||||||
| from authentik.outposts.controllers.base import BaseController, ControllerException | from authentik.outposts.controllers.base import BaseController, ControllerException | ||||||
| from authentik.outposts.controllers.docker import DockerClient | from authentik.outposts.controllers.docker import DockerClient | ||||||
| @ -29,6 +31,7 @@ from authentik.outposts.models import ( | |||||||
|     DockerServiceConnection, |     DockerServiceConnection, | ||||||
|     KubernetesServiceConnection, |     KubernetesServiceConnection, | ||||||
|     Outpost, |     Outpost, | ||||||
|  |     OutpostModel, | ||||||
|     OutpostServiceConnection, |     OutpostServiceConnection, | ||||||
|     OutpostType, |     OutpostType, | ||||||
|     ServiceConnectionInvalid, |     ServiceConnectionInvalid, | ||||||
| @ -41,7 +44,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController | |||||||
| from authentik.providers.rac.controllers.kubernetes import RACKubernetesController | from authentik.providers.rac.controllers.kubernetes import RACKubernetesController | ||||||
| from authentik.providers.radius.controllers.docker import RadiusDockerController | from authentik.providers.radius.controllers.docker import RadiusDockerController | ||||||
| from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | ||||||
| from authentik.tasks.models import Task | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | ||||||
| @ -80,8 +83,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: | |||||||
|     return None |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Update cached state of service connection.")) | @CELERY_APP.task() | ||||||
| def outpost_service_connection_monitor(connection_pk: Any): | def outpost_service_connection_state(connection_pk: Any): | ||||||
|     """Update cached state of a service connection""" |     """Update cached state of a service connection""" | ||||||
|     connection: OutpostServiceConnection = ( |     connection: OutpostServiceConnection = ( | ||||||
|         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() |         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() | ||||||
| @ -105,11 +108,37 @@ def outpost_service_connection_monitor(connection_pk: Any): | |||||||
|     cache.set(connection.state_key, state, timeout=None) |     cache.set(connection.state_key, state, timeout=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Create/update/monitor/delete the deployment of an Outpost.")) | @CELERY_APP.task( | ||||||
| def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): |     bind=True, | ||||||
|  |     base=SystemTask, | ||||||
|  |     throws=(DatabaseError, ProgrammingError, InternalError), | ||||||
|  | ) | ||||||
|  | @prefill_task | ||||||
|  | def outpost_service_connection_monitor(self: SystemTask): | ||||||
|  |     """Regularly check the state of Outpost Service Connections""" | ||||||
|  |     connections = OutpostServiceConnection.objects.all() | ||||||
|  |     for connection in connections.iterator(): | ||||||
|  |         outpost_service_connection_state.delay(connection.pk) | ||||||
|  |     self.set_status( | ||||||
|  |         TaskStatus.SUCCESSFUL, | ||||||
|  |         f"Successfully updated {len(connections)} connections.", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task( | ||||||
|  |     throws=(DatabaseError, ProgrammingError, InternalError), | ||||||
|  | ) | ||||||
|  | def outpost_controller_all(): | ||||||
|  |     """Launch Controller for all Outposts which support it""" | ||||||
|  |     for outpost in Outpost.objects.exclude(service_connection=None): | ||||||
|  |         outpost_controller.delay(outpost.pk.hex, "up", from_cache=False) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
|  | def outpost_controller( | ||||||
|  |     self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False | ||||||
|  | ): | ||||||
|     """Create/update/monitor/delete the deployment of an Outpost""" |     """Create/update/monitor/delete the deployment of an Outpost""" | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|     self.set_uid(outpost_pk) |  | ||||||
|     logs = [] |     logs = [] | ||||||
|     if from_cache: |     if from_cache: | ||||||
|         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) |         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||||
| @ -130,65 +159,125 @@ def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = F | |||||||
|             logs = getattr(controller, f"{action}_with_logs")() |             logs = getattr(controller, f"{action}_with_logs")() | ||||||
|             LOGGER.debug("-----------------Outpost Controller logs end-------------------") |             LOGGER.debug("-----------------Outpost Controller logs end-------------------") | ||||||
|     except (ControllerException, ServiceConnectionInvalid) as exc: |     except (ControllerException, ServiceConnectionInvalid) as exc: | ||||||
|         self.error(exc) |         self.set_error(exc) | ||||||
|     else: |     else: | ||||||
|         if from_cache: |         if from_cache: | ||||||
|             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) |             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||||
|         self.logs(logs) |         self.set_status(TaskStatus.SUCCESSFUL, *logs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens.")) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def outpost_token_ensurer(): | @prefill_task | ||||||
|     """ | def outpost_token_ensurer(self: SystemTask): | ||||||
|     Periodically ensure that all Outposts have valid Service Accounts and Tokens |     """Periodically ensure that all Outposts have valid Service Accounts | ||||||
|     """ |     and Tokens""" | ||||||
|     self: Task = CurrentTask.get_task() |  | ||||||
|     all_outposts = Outpost.objects.all() |     all_outposts = Outpost.objects.all() | ||||||
|     for outpost in all_outposts: |     for outpost in all_outposts: | ||||||
|         _ = outpost.token |         _ = outpost.token | ||||||
|         outpost.build_user_permissions(outpost.user) |         outpost.build_user_permissions(outpost.user) | ||||||
|     self.info(f"Successfully checked {len(all_outposts)} Outposts.") |     self.set_status( | ||||||
|  |         TaskStatus.SUCCESSFUL, | ||||||
|  |         f"Successfully checked {len(all_outposts)} Outposts.", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Send update to outpost")) | @CELERY_APP.task() | ||||||
| def outpost_send_update(pk: Any): | def outpost_post_save(model_class: str, model_pk: Any): | ||||||
|     """Update outpost instance""" |     """If an Outpost is saved, Ensure that token is created/updated | ||||||
|     outpost = Outpost.objects.filter(pk=pk).first() |  | ||||||
|     if not outpost: |     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, | ||||||
|  |     we send a message down the relevant OutpostModels WS connection to trigger an update""" | ||||||
|  |     model: Model = path_to_class(model_class) | ||||||
|  |     try: | ||||||
|  |         instance = model.objects.get(pk=model_pk) | ||||||
|  |     except model.DoesNotExist: | ||||||
|  |         LOGGER.warning("Model does not exist", model=model, pk=model_pk) | ||||||
|         return |         return | ||||||
|  |  | ||||||
|  |     if isinstance(instance, Outpost): | ||||||
|  |         LOGGER.debug("Trigger reconcile for outpost", instance=instance) | ||||||
|  |         outpost_controller.delay(str(instance.pk)) | ||||||
|  |  | ||||||
|  |     if isinstance(instance, OutpostModel | Outpost): | ||||||
|  |         LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) | ||||||
|  |         outpost_send_update(instance) | ||||||
|  |  | ||||||
|  |     if isinstance(instance, OutpostServiceConnection): | ||||||
|  |         LOGGER.debug("triggering ServiceConnection state update", instance=instance) | ||||||
|  |         outpost_service_connection_state.delay(str(instance.pk)) | ||||||
|  |  | ||||||
|  |     for field in instance._meta.get_fields(): | ||||||
|  |         # Each field is checked if it has a `related_model` attribute (when ForeignKeys or M2Ms) | ||||||
|  |         # are used, and if it has a value | ||||||
|  |         if not hasattr(field, "related_model"): | ||||||
|  |             continue | ||||||
|  |         if not field.related_model: | ||||||
|  |             continue | ||||||
|  |         if not issubclass(field.related_model, OutpostModel): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         field_name = f"{field.name}_set" | ||||||
|  |         if not hasattr(instance, field_name): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         LOGGER.debug("triggering outpost update from field", field=field.name) | ||||||
|  |         # Because the Outpost Model has an M2M to Provider, | ||||||
|  |         # we have to iterate over the entire QS | ||||||
|  |         for reverse in getattr(instance, field_name).all(): | ||||||
|  |             outpost_send_update(reverse) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def outpost_send_update(model_instance: Model): | ||||||
|  |     """Send outpost update to all registered outposts, regardless to which authentik | ||||||
|  |     instance they are connected""" | ||||||
|  |     channel_layer = get_channel_layer() | ||||||
|  |     if isinstance(model_instance, OutpostModel): | ||||||
|  |         for outpost in model_instance.outpost_set.all(): | ||||||
|  |             _outpost_single_update(outpost, channel_layer) | ||||||
|  |     elif isinstance(model_instance, Outpost): | ||||||
|  |         _outpost_single_update(model_instance, channel_layer) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _outpost_single_update(outpost: Outpost, layer=None): | ||||||
|  |     """Update outpost instances connected to a single outpost""" | ||||||
|     # Ensure token again, because this function is called when anything related to an |     # Ensure token again, because this function is called when anything related to an | ||||||
|     # OutpostModel is saved, so we can be sure permissions are right |     # OutpostModel is saved, so we can be sure permissions are right | ||||||
|     _ = outpost.token |     _ = outpost.token | ||||||
|     outpost.build_user_permissions(outpost.user) |     outpost.build_user_permissions(outpost.user) | ||||||
|  |     if not layer:  # pragma: no cover | ||||||
|         layer = get_channel_layer() |         layer = get_channel_layer() | ||||||
|     group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} |     group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||||
|     LOGGER.debug("sending update", channel=group, outpost=outpost) |     LOGGER.debug("sending update", channel=group, outpost=outpost) | ||||||
|     async_to_sync(layer.group_send)(group, {"type": "event.update"}) |     async_to_sync(layer.group_send)(group, {"type": "event.update"}) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Checks the local environment and create Service connections.")) | @CELERY_APP.task( | ||||||
| def outpost_connection_discovery(): |     base=SystemTask, | ||||||
|  |     bind=True, | ||||||
|  | ) | ||||||
|  | def outpost_connection_discovery(self: SystemTask): | ||||||
|     """Checks the local environment and create Service connections.""" |     """Checks the local environment and create Service connections.""" | ||||||
|     self: Task = CurrentTask.get_task() |     messages = [] | ||||||
|     if not CONFIG.get_bool("outposts.discover"): |     if not CONFIG.get_bool("outposts.discover"): | ||||||
|         self.info("Outpost integration discovery is disabled") |         messages.append("Outpost integration discovery is disabled") | ||||||
|  |         self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|         return |         return | ||||||
|     # Explicitly check against token filename, as that's |     # Explicitly check against token filename, as that's | ||||||
|     # only present when the integration is enabled |     # only present when the integration is enabled | ||||||
|     if Path(SERVICE_TOKEN_FILENAME).exists(): |     if Path(SERVICE_TOKEN_FILENAME).exists(): | ||||||
|         self.info("Detected in-cluster Kubernetes Config") |         messages.append("Detected in-cluster Kubernetes Config") | ||||||
|         if not KubernetesServiceConnection.objects.filter(local=True).exists(): |         if not KubernetesServiceConnection.objects.filter(local=True).exists(): | ||||||
|             self.info("Created Service Connection for in-cluster") |             messages.append("Created Service Connection for in-cluster") | ||||||
|             KubernetesServiceConnection.objects.create( |             KubernetesServiceConnection.objects.create( | ||||||
|                 name="Local Kubernetes Cluster", local=True, kubeconfig={} |                 name="Local Kubernetes Cluster", local=True, kubeconfig={} | ||||||
|             ) |             ) | ||||||
|     # For development, check for the existence of a kubeconfig file |     # For development, check for the existence of a kubeconfig file | ||||||
|     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() |     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() | ||||||
|     if kubeconfig_path.exists(): |     if kubeconfig_path.exists(): | ||||||
|         self.info("Detected kubeconfig") |         messages.append("Detected kubeconfig") | ||||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" |         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||||
|         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): |         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||||
|             self.info("Creating kubeconfig Service Connection") |             messages.append("Creating kubeconfig Service Connection") | ||||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: |             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||||
|                 KubernetesServiceConnection.objects.create( |                 KubernetesServiceConnection.objects.create( | ||||||
|                     name=kubeconfig_local_name, |                     name=kubeconfig_local_name, | ||||||
| @ -197,18 +286,20 @@ def outpost_connection_discovery(): | |||||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path |     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||||
|     socket = Path(unix_socket_path) |     socket = Path(unix_socket_path) | ||||||
|     if socket.exists() and access(socket, R_OK): |     if socket.exists() and access(socket, R_OK): | ||||||
|         self.info("Detected local docker socket") |         messages.append("Detected local docker socket") | ||||||
|         if len(DockerServiceConnection.objects.filter(local=True)) == 0: |         if len(DockerServiceConnection.objects.filter(local=True)) == 0: | ||||||
|             self.info("Created Service Connection for docker") |             messages.append("Created Service Connection for docker") | ||||||
|             DockerServiceConnection.objects.create( |             DockerServiceConnection.objects.create( | ||||||
|                 name="Local Docker connection", |                 name="Local Docker connection", | ||||||
|                 local=True, |                 local=True, | ||||||
|                 url=unix_socket_path, |                 url=unix_socket_path, | ||||||
|             ) |             ) | ||||||
|  |     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Terminate session on all outposts.")) | @CELERY_APP.task() | ||||||
| def outpost_session_end(session_id: str): | def outpost_session_end(session_id: str): | ||||||
|  |     """Update outpost instances connected to a single outpost""" | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|     hashed_session_id = hash_session_key(session_id) |     hashed_session_id = hash_session_key(session_id) | ||||||
|     for outpost in Outpost.objects.all(): |     for outpost in Outpost.objects.all(): | ||||||
|  | |||||||
| @ -37,7 +37,6 @@ class OutpostTests(TestCase): | |||||||
|  |  | ||||||
|         # We add a provider, user should only have access to outpost and provider |         # We add a provider, user should only have access to outpost and provider | ||||||
|         outpost.providers.add(provider) |         outpost.providers.add(provider) | ||||||
|         provider.refresh_from_db() |  | ||||||
|         permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( |         permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( | ||||||
|             "content_type__model" |             "content_type__model" | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -15,7 +15,6 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | |||||||
|     def proxy_set_defaults(self): |     def proxy_set_defaults(self): | ||||||
|         from authentik.providers.proxy.models import ProxyProvider |         from authentik.providers.proxy.models import ProxyProvider | ||||||
|  |  | ||||||
|         # TODO: figure out if this can be in pre_save + post_save signals |  | ||||||
|         for provider in ProxyProvider.objects.all(): |         for provider in ProxyProvider.objects.all(): | ||||||
|             provider.set_oauth_defaults() |             provider.set_oauth_defaults() | ||||||
|             provider.save() |             provider.save() | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """Proxy provider signals""" |  | ||||||
|  |  | ||||||
| from django.db.models.signals import pre_delete |  | ||||||
| from django.dispatch import receiver |  | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession |  | ||||||
| from authentik.providers.proxy.tasks import proxy_on_logout |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) |  | ||||||
| def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): |  | ||||||
|     """Catch logout by expiring sessions being deleted""" |  | ||||||
|     proxy_on_logout.send(instance.session.session_key) |  | ||||||
| @ -1,26 +0,0 @@ | |||||||
| """proxy provider tasks""" |  | ||||||
|  |  | ||||||
| from asgiref.sync import async_to_sync |  | ||||||
| from channels.layers import get_channel_layer |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from dramatiq.actor import actor |  | ||||||
|  |  | ||||||
| from authentik.outposts.consumer import OUTPOST_GROUP |  | ||||||
| from authentik.outposts.models import Outpost, OutpostType |  | ||||||
| from authentik.providers.oauth2.id_token import hash_session_key |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Terminate session on Proxy outpost.")) |  | ||||||
| def proxy_on_logout(session_id: str): |  | ||||||
|     layer = get_channel_layer() |  | ||||||
|     hashed_session_id = hash_session_key(session_id) |  | ||||||
|     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): |  | ||||||
|         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} |  | ||||||
|         async_to_sync(layer.group_send)( |  | ||||||
|             group, |  | ||||||
|             { |  | ||||||
|                 "type": "event.provider.specific", |  | ||||||
|                 "sub_type": "logout", |  | ||||||
|                 "session_id": hashed_session_id, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
| @ -17,7 +17,6 @@ from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.lib.utils.time import timedelta_string_validator | from authentik.lib.utils.time import timedelta_string_validator | ||||||
| from authentik.outposts.models import OutpostModel |  | ||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -38,7 +37,7 @@ class AuthenticationMode(models.TextChoices): | |||||||
|     PROMPT = "prompt" |     PROMPT = "prompt" | ||||||
|  |  | ||||||
|  |  | ||||||
| class RACProvider(OutpostModel, Provider): | class RACProvider(Provider): | ||||||
|     """Remotely access computers/servers via RDP/SSH/VNC.""" |     """Remotely access computers/servers via RDP/SSH/VNC.""" | ||||||
|  |  | ||||||
|     settings = models.JSONField(default=dict) |     settings = models.JSONField(default=dict) | ||||||
|  | |||||||
| @ -44,5 +44,5 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie | |||||||
|     filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"] |     filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"] | ||||||
|     search_fields = ["name", "url"] |     search_fields = ["name", "url"] | ||||||
|     ordering = ["name", "url"] |     ordering = ["name", "url"] | ||||||
|     sync_task = scim_sync |     sync_single_task = scim_sync | ||||||
|     sync_objects_task = scim_sync_objects |     sync_objects_task = scim_sync_objects | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
|  | from authentik.providers.scim.tasks import scim_sync, sync_tasks | ||||||
| from authentik.tenants.management import TenantCommand | from authentik.tenants.management import TenantCommand | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -20,5 +21,4 @@ class Command(TenantCommand): | |||||||
|             if not provider: |             if not provider: | ||||||
|                 LOGGER.warning("Provider does not exist", name=provider_name) |                 LOGGER.warning("Provider does not exist", name=provider_name) | ||||||
|                 continue |                 continue | ||||||
|             for schedule in provider.schedules.all(): |             sync_tasks.trigger_single_task(provider, scim_sync).get() | ||||||
|                 schedule.send().get_result() |  | ||||||
|  | |||||||
| @ -7,7 +7,6 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from dramatiq.actor import Actor |  | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes | from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes | ||||||
| @ -100,12 +99,6 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|     def icon_url(self) -> str | None: |     def icon_url(self) -> str | None: | ||||||
|         return static("authentik/sources/scim.png") |         return static("authentik/sources/scim.png") | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def sync_actor(self) -> Actor: |  | ||||||
|         from authentik.providers.scim.tasks import scim_sync |  | ||||||
|  |  | ||||||
|         return scim_sync |  | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] |         self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] | ||||||
|     ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: |     ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: | ||||||
|  | |||||||
							
								
								
									
authentik/providers/scim/settings.py (new file, 13 lines)
							| @ -0,0 +1,13 @@ | |||||||
|  | """SCIM task Settings""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
|  | CELERY_BEAT_SCHEDULE = { | ||||||
|  |     "providers_scim_sync": { | ||||||
|  |         "task": "authentik.providers.scim.tasks.scim_sync_all", | ||||||
|  |         "schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"), | ||||||
|  |         "options": {"queue": "authentik_scheduled"}, | ||||||
|  |     }, | ||||||
|  | } | ||||||
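This SCIM settings module follows the same pattern as the outposts one: beat only enqueues scim_sync_all every four hours, and that task (defined further down in tasks.py) delegates to sync_tasks.sync_all(scim_sync). Assuming a running Celery worker and broker, the same entry point can also be queued by hand:

# Sketch of triggering the scheduled SCIM sync manually; assumes the Celery
# worker and broker are available.
from authentik.providers.scim.tasks import scim_sync_all

scim_sync_all.delay()  # enqueue for a worker, the same way beat would
scim_sync_all()        # or run inline in-process, as the test below does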
| @ -2,10 +2,11 @@ | |||||||
|  |  | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch | from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     SCIMProvider, |     SCIMProvider, | ||||||
|     task_sync_direct_dispatch=scim_sync_direct_dispatch, |     task_sync_single=scim_sync, | ||||||
|     task_sync_m2m_dispatch=scim_sync_m2m_dispatch, |     task_sync_direct=scim_sync_direct, | ||||||
|  |     task_sync_m2m=scim_sync_m2m, | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,40 +1,37 @@ | |||||||
| """SCIM Provider tasks""" | """SCIM Provider tasks""" | ||||||
|  |  | ||||||
| from django.utils.translation import gettext_lazy as _ | from authentik.events.system_tasks import SystemTask | ||||||
| from dramatiq.actor import actor | from authentik.lib.sync.outgoing.exceptions import TransientSyncException | ||||||
|  |  | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(SCIMProvider) | sync_tasks = SyncTasks(SCIMProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync SCIM provider objects.")) | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def scim_sync_objects(*args, **kwargs): | def scim_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Full sync for SCIM provider.")) | @CELERY_APP.task( | ||||||
| def scim_sync(provider_pk: int, *args, **kwargs): |     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | ||||||
|  | ) | ||||||
|  | def scim_sync(self, provider_pk: int, *args, **kwargs): | ||||||
|     """Run full sync for SCIM provider""" |     """Run full sync for SCIM provider""" | ||||||
|     return sync_tasks.sync(provider_pk, scim_sync_objects) |     return sync_tasks.sync_single(self, provider_pk, scim_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a direct object (user, group) for SCIM provider.")) | @CELERY_APP.task() | ||||||
|  | def scim_sync_all(): | ||||||
|  |     return sync_tasks.sync_all(scim_sync) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def scim_sync_direct(*args, **kwargs): | def scim_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers.")) | @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||||
| def scim_sync_direct_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Sync a related object (memberships) for SCIM provider.")) |  | ||||||
| def scim_sync_m2m(*args, **kwargs): | def scim_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers.")) |  | ||||||
| def scim_sync_m2m_dispatch(*args, **kwargs): |  | ||||||
|     return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs) |  | ||||||
|  | |||||||
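With the switch back to Celery, the scheduled entry point is `scim_sync_all`, which fans out to one `scim_sync` per provider via `SyncTasks.sync_all`, while the direct and m2m signal tasks call into the sync machinery without the extra dispatch layer. The updated tests below trigger a single provider sync through `sync_tasks.trigger_single_task`; an operator could do the same from a shell (assumed usage, mirroring the test code in this diff):

```python
# Assumed one-off usage from a Django shell, mirroring the updated tests:
# run a full sync for one SCIM provider and block until it finishes.
from authentik.providers.scim.models import SCIMProvider
from authentik.providers.scim.tasks import scim_sync, sync_tasks

provider = SCIMProvider.objects.first()
if provider:
    result = sync_tasks.trigger_single_task(provider, scim_sync)
    result.get()
```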
| @ -8,7 +8,7 @@ from authentik.core.models import Application | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.clients.base import SCIMClient | from authentik.providers.scim.clients.base import SCIMClient | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync | from authentik.providers.scim.tasks import scim_sync_all | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMClientTests(TestCase): | class SCIMClientTests(TestCase): | ||||||
| @ -85,6 +85,6 @@ class SCIMClientTests(TestCase): | |||||||
|             self.assertEqual(mock.call_count, 1) |             self.assertEqual(mock.call_count, 1) | ||||||
|             self.assertEqual(mock.request_history[0].method, "GET") |             self.assertEqual(mock.request_history[0].method, "GET") | ||||||
|  |  | ||||||
|     def test_scim_sync(self): |     def test_scim_sync_all(self): | ||||||
|         """test scim_sync task""" |         """test scim_sync_all task""" | ||||||
|         scim_sync.send(self.provider.pk).get_result() |         scim_sync_all() | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync | from authentik.providers.scim.tasks import scim_sync, sync_tasks | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -79,15 +79,17 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             scim_sync.send(self.provider.pk) |             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 4) |             self.assertEqual(mocker.call_count, 6) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "POST") |             self.assertEqual(mocker.request_history[1].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|  |             self.assertEqual(mocker.request_history[4].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[5].method, "POST") | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[1].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "emails": [], |                     "emails": [], | ||||||
| @ -99,7 +101,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[5].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
| @ -167,15 +169,17 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             scim_sync.send(self.provider.pk) |             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 4) |             self.assertEqual(mocker.call_count, 6) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "POST") |             self.assertEqual(mocker.request_history[1].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|  |             self.assertEqual(mocker.request_history[4].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[5].method, "POST") | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[1].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "active": True, |                     "active": True, | ||||||
| @ -187,7 +191,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[5].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
| @ -283,15 +287,17 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             scim_sync.send(self.provider.pk) |             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 4) |             self.assertEqual(mocker.call_count, 6) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "POST") |             self.assertEqual(mocker.request_history[1].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|  |             self.assertEqual(mocker.request_history[4].method, "GET") | ||||||
|  |             self.assertEqual(mocker.request_history[5].method, "POST") | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[1].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "emails": [], |                     "emails": [], | ||||||
| @ -303,7 +309,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[5].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
|  | |||||||
| @ -9,11 +9,11 @@ from requests_mock import Mocker | |||||||
|  |  | ||||||
| from authentik.blueprints.tests import apply_blueprint | from authentik.blueprints.tests import apply_blueprint | ||||||
| from authentik.core.models import Application, Group, User | from authentik.core.models import Application, Group, User | ||||||
|  | from authentik.events.models import SystemTask | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.sync.outgoing.base import SAFE_METHODS | from authentik.lib.sync.outgoing.base import SAFE_METHODS | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync, scim_sync_objects | from authentik.providers.scim.tasks import scim_sync, sync_tasks | ||||||
| from authentik.tasks.models import Task |  | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -356,7 +356,7 @@ class SCIMUserTests(TestCase): | |||||||
|             email=f"{uid}@goauthentik.io", |             email=f"{uid}@goauthentik.io", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         scim_sync.send(self.provider.pk) |         sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|         self.assertEqual(mock.call_count, 5) |         self.assertEqual(mock.call_count, 5) | ||||||
|         self.assertEqual(mock.request_history[0].method, "GET") |         self.assertEqual(mock.request_history[0].method, "GET") | ||||||
| @ -428,19 +428,14 @@ class SCIMUserTests(TestCase): | |||||||
|                 email=f"{uid}@goauthentik.io", |                 email=f"{uid}@goauthentik.io", | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             scim_sync.send(self.provider.pk) |             sync_tasks.trigger_single_task(self.provider, scim_sync).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mock.call_count, 3) |             self.assertEqual(mock.call_count, 3) | ||||||
|             for request in mock.request_history: |             for request in mock.request_history: | ||||||
|                 self.assertIn(request.method, SAFE_METHODS) |                 self.assertIn(request.method, SAFE_METHODS) | ||||||
|         task = list( |         task = SystemTask.objects.filter(uid=slugify(self.provider.name)).first() | ||||||
|             Task.objects.filter( |  | ||||||
|                 actor_name=scim_sync_objects.actor_name, |  | ||||||
|                 _uid=slugify(self.provider.name), |  | ||||||
|             ).order_by("-mtime") |  | ||||||
|         )[1] |  | ||||||
|         self.assertIsNotNone(task) |         self.assertIsNotNone(task) | ||||||
|         drop_msg = task._messages[3] |         drop_msg = task.messages[3] | ||||||
|         self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") |         self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") | ||||||
|         self.assertIsNotNone(drop_msg["attributes"]["url"]) |         self.assertIsNotNone(drop_msg["attributes"]["url"]) | ||||||
|         self.assertIsNotNone(drop_msg["attributes"]["body"]) |         self.assertIsNotNone(drop_msg["attributes"]["body"]) | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| """authentik core celery""" | """authentik core celery""" | ||||||
|  |  | ||||||
| import os | import os | ||||||
|  | from collections.abc import Callable | ||||||
| from contextvars import ContextVar | from contextvars import ContextVar | ||||||
| from logging.config import dictConfig | from logging.config import dictConfig | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @ -15,9 +16,12 @@ from celery.signals import ( | |||||||
|     task_internal_error, |     task_internal_error, | ||||||
|     task_postrun, |     task_postrun, | ||||||
|     task_prerun, |     task_prerun, | ||||||
|  |     worker_ready, | ||||||
| ) | ) | ||||||
| from celery.worker.control import inspect_command | from celery.worker.control import inspect_command | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  | from django.db import ProgrammingError | ||||||
|  | from django_tenants.utils import get_public_schema_name | ||||||
| from structlog.contextvars import STRUCTLOG_KEY_PREFIX | from structlog.contextvars import STRUCTLOG_KEY_PREFIX | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||||
| @ -65,10 +69,7 @@ def task_postrun_hook(task_id: str, task, *args, retval=None, state=None, **kwar | |||||||
|     """Log task_id on worker""" |     """Log task_id on worker""" | ||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     LOGGER.info( |     LOGGER.info( | ||||||
|         "Task finished", |         "Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state | ||||||
|         task_id=task_id.replace("-", ""), |  | ||||||
|         task_name=task.__name__, |  | ||||||
|         state=state, |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -82,12 +83,51 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | |||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     if not should_ignore_exception(exception): |     if not should_ignore_exception(exception): | ||||||
|         Event.new( |         Event.new( | ||||||
|             EventAction.SYSTEM_EXCEPTION, |             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id | ||||||
|             message=exception_to_string(exception), |  | ||||||
|             task_id=task_id, |  | ||||||
|         ).save() |         ).save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_startup_tasks_default_tenant() -> list[Callable]: | ||||||
|  |     """Get all tasks to be run on startup for the default tenant""" | ||||||
|  |     from authentik.outposts.tasks import outpost_connection_discovery | ||||||
|  |  | ||||||
|  |     return [ | ||||||
|  |         outpost_connection_discovery, | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_startup_tasks_all_tenants() -> list[Callable]: | ||||||
|  |     """Get all tasks to be run on startup for all tenants""" | ||||||
|  |     return [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @worker_ready.connect | ||||||
|  | def worker_ready_hook(*args, **kwargs): | ||||||
|  |     """Run certain tasks on worker start""" | ||||||
|  |     from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |     LOGGER.info("Dispatching startup tasks...") | ||||||
|  |  | ||||||
|  |     def _run_task(task: Callable): | ||||||
|  |         try: | ||||||
|  |             task.delay() | ||||||
|  |         except ProgrammingError as exc: | ||||||
|  |             LOGGER.warning("Startup task failed", task=task, exc=exc) | ||||||
|  |  | ||||||
|  |     for task in _get_startup_tasks_default_tenant(): | ||||||
|  |         with Tenant.objects.get(schema_name=get_public_schema_name()): | ||||||
|  |             _run_task(task) | ||||||
|  |  | ||||||
|  |     for task in _get_startup_tasks_all_tenants(): | ||||||
|  |         for tenant in Tenant.objects.filter(ready=True): | ||||||
|  |             with tenant: | ||||||
|  |                 _run_task(task) | ||||||
|  |  | ||||||
|  |     from authentik.blueprints.v1.tasks import start_blueprint_watcher | ||||||
|  |  | ||||||
|  |     start_blueprint_watcher() | ||||||
|  |  | ||||||
|  |  | ||||||
| class LivenessProbe(bootsteps.StartStopStep): | class LivenessProbe(bootsteps.StartStopStep): | ||||||
|     """Add a timed task to touch a temporary file for healthchecking reasons""" |     """Add a timed task to touch a temporary file for healthchecking reasons""" | ||||||
|  |  | ||||||
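The new `worker_ready` hook runs startup tasks once per worker boot: tasks from `_get_startup_tasks_default_tenant()` are dispatched under the public schema, and tasks from `_get_startup_tasks_all_tenants()` (currently an empty list) are dispatched once per ready tenant, with `ProgrammingError` swallowed so a worker can start before migrations have run. As a hypothetical illustration of how a per-tenant startup task would be registered, the existing `clean_expired_models` task (referenced in the beat schedule below) could be returned from that helper:

```python
from collections.abc import Callable


def _get_startup_tasks_all_tenants() -> list[Callable]:
    """Get all tasks to be run on startup for all tenants (hypothetical example)."""
    # clean_expired_models stands in for any Celery task object exposing .delay();
    # it is not actually registered as a startup task in this change.
    from authentik.core.tasks import clean_expired_models

    return [
        clean_expired_models,
    ]
```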
| @ -4,9 +4,9 @@ import importlib | |||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from tempfile import gettempdir |  | ||||||
|  |  | ||||||
| import orjson | import orjson | ||||||
|  | from celery.schedules import crontab | ||||||
| from sentry_sdk import set_tag | from sentry_sdk import set_tag | ||||||
| from xmlsec import enable_debug_trace | from xmlsec import enable_debug_trace | ||||||
|  |  | ||||||
| @ -65,18 +65,14 @@ SHARED_APPS = [ | |||||||
|     "pgactivity", |     "pgactivity", | ||||||
|     "pglock", |     "pglock", | ||||||
|     "channels", |     "channels", | ||||||
|     "django_dramatiq_postgres", |  | ||||||
|     "authentik.tasks", |  | ||||||
| ] | ] | ||||||
| TENANT_APPS = [ | TENANT_APPS = [ | ||||||
|     "django.contrib.auth", |     "django.contrib.auth", | ||||||
|     "django.contrib.contenttypes", |     "django.contrib.contenttypes", | ||||||
|     "django.contrib.sessions", |     "django.contrib.sessions", | ||||||
|     "pgtrigger", |  | ||||||
|     "authentik.admin", |     "authentik.admin", | ||||||
|     "authentik.api", |     "authentik.api", | ||||||
|     "authentik.crypto", |     "authentik.crypto", | ||||||
|     "authentik.events", |  | ||||||
|     "authentik.flows", |     "authentik.flows", | ||||||
|     "authentik.outposts", |     "authentik.outposts", | ||||||
|     "authentik.policies.dummy", |     "authentik.policies.dummy", | ||||||
| @ -124,7 +120,6 @@ TENANT_APPS = [ | |||||||
|     "authentik.stages.user_login", |     "authentik.stages.user_login", | ||||||
|     "authentik.stages.user_logout", |     "authentik.stages.user_logout", | ||||||
|     "authentik.stages.user_write", |     "authentik.stages.user_write", | ||||||
|     "authentik.tasks.schedules", |  | ||||||
|     "authentik.brands", |     "authentik.brands", | ||||||
|     "authentik.blueprints", |     "authentik.blueprints", | ||||||
|     "guardian", |     "guardian", | ||||||
| @ -171,7 +166,6 @@ SPECTACULAR_SETTINGS = { | |||||||
|         "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", |         "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", | ||||||
|         "UserTypeEnum": "authentik.core.models.UserTypes", |         "UserTypeEnum": "authentik.core.models.UserTypes", | ||||||
|         "OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction", |         "OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction", | ||||||
|         "TaskAggregatedStatusEnum": "authentik.tasks.models.TaskStatus", |  | ||||||
|     }, |     }, | ||||||
|     "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False, |     "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False, | ||||||
|     "ENUM_GENERATE_CHOICE_DESCRIPTION": False, |     "ENUM_GENERATE_CHOICE_DESCRIPTION": False, | ||||||
| @ -347,85 +341,37 @@ USE_TZ = True | |||||||
|  |  | ||||||
| LOCALE_PATHS = ["./locale"] | LOCALE_PATHS = ["./locale"] | ||||||
|  |  | ||||||
|  | CELERY = { | ||||||
| # Tests |     "task_soft_time_limit": 600, | ||||||
|  |     "worker_max_tasks_per_child": 50, | ||||||
| TEST = False |     "worker_concurrency": CONFIG.get_int("worker.concurrency"), | ||||||
| TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" |     "beat_schedule": { | ||||||
|  |         "clean_expired_models": { | ||||||
|  |             "task": "authentik.core.tasks.clean_expired_models", | ||||||
| # Dramatiq |             "schedule": crontab(minute="2-59/5"), | ||||||
|  |             "options": {"queue": "authentik_scheduled"}, | ||||||
| DRAMATIQ = { |  | ||||||
|     "broker_class": "authentik.tasks.broker.Broker", |  | ||||||
|     "channel_prefix": "authentik", |  | ||||||
|     "task_model": "authentik.tasks.models.Task", |  | ||||||
|     "task_purge_interval": timedelta_from_string( |  | ||||||
|         CONFIG.get("worker.task_purge_interval") |  | ||||||
|     ).total_seconds(), |  | ||||||
|     "task_expiration": timedelta_from_string(CONFIG.get("worker.task_expiration")).total_seconds(), |  | ||||||
|     "autodiscovery": { |  | ||||||
|         "enabled": True, |  | ||||||
|         "setup_module": "authentik.tasks.setup", |  | ||||||
|         "apps_prefix": "authentik", |  | ||||||
|         }, |         }, | ||||||
|     "worker": { |         "user_cleanup": { | ||||||
|         "processes": CONFIG.get_int("worker.processes", 2), |             "task": "authentik.core.tasks.clean_temporary_users", | ||||||
|         "threads": CONFIG.get_int("worker.threads", 1), |             "schedule": crontab(minute="9-59/5"), | ||||||
|         "consumer_listen_timeout": timedelta_from_string( |             "options": {"queue": "authentik_scheduled"}, | ||||||
|             CONFIG.get("worker.consumer_listen_timeout") |  | ||||||
|         ).total_seconds(), |  | ||||||
|         "watch_folder": BASE_DIR / "authentik", |  | ||||||
|         }, |         }, | ||||||
|     "scheduler_class": "authentik.tasks.schedules.scheduler.Scheduler", |  | ||||||
|     "schedule_model": "authentik.tasks.schedules.models.Schedule", |  | ||||||
|     "scheduler_interval": timedelta_from_string( |  | ||||||
|         CONFIG.get("worker.scheduler_interval") |  | ||||||
|     ).total_seconds(), |  | ||||||
|     "middlewares": ( |  | ||||||
|         ("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}), |  | ||||||
|         # TODO: fixme |  | ||||||
|         # ("dramatiq.middleware.prometheus.Prometheus", {}), |  | ||||||
|         ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), |  | ||||||
|         ("dramatiq.middleware.age_limit.AgeLimit", {}), |  | ||||||
|         ( |  | ||||||
|             "dramatiq.middleware.time_limit.TimeLimit", |  | ||||||
|             { |  | ||||||
|                 "time_limit": timedelta_from_string( |  | ||||||
|                     CONFIG.get("worker.task_default_time_limit") |  | ||||||
|                 ).total_seconds() |  | ||||||
|                 * 1000 |  | ||||||
|     }, |     }, | ||||||
|  |     "beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler", | ||||||
|  |     "task_create_missing_queues": True, | ||||||
|  |     "task_default_queue": "authentik", | ||||||
|  |     "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")), | ||||||
|  |     "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")), | ||||||
|  |     "broker_transport_options": CONFIG.get_dict_from_b64_json( | ||||||
|  |         "broker.transport_options", {"retry_policy": {"timeout": 5.0}} | ||||||
|     ), |     ), | ||||||
|         ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), |     "result_backend_transport_options": CONFIG.get_dict_from_b64_json( | ||||||
|         ("dramatiq.middleware.callbacks.Callbacks", {}), |         "result_backend.transport_options", {"retry_policy": {"timeout": 5.0}} | ||||||
|         ("dramatiq.middleware.pipelines.Pipelines", {}), |  | ||||||
|         ( |  | ||||||
|             "dramatiq.middleware.retries.Retries", |  | ||||||
|             {"max_retries": CONFIG.get_int("worker.task_max_retries") if not TEST else 0}, |  | ||||||
|     ), |     ), | ||||||
|         ("dramatiq.results.middleware.Results", {"store_results": True}), |     "redis_retry_on_timeout": True, | ||||||
|         ("django_dramatiq_postgres.middleware.CurrentTask", {}), |  | ||||||
|         ("authentik.tasks.middleware.TenantMiddleware", {}), |  | ||||||
|         ("authentik.tasks.middleware.RelObjMiddleware", {}), |  | ||||||
|         ("authentik.tasks.middleware.MessagesMiddleware", {}), |  | ||||||
|         ("authentik.tasks.middleware.LoggingMiddleware", {}), |  | ||||||
|         ("authentik.tasks.middleware.DescriptionMiddleware", {}), |  | ||||||
|         ("authentik.tasks.middleware.WorkerStatusMiddleware", {}), |  | ||||||
|         ( |  | ||||||
|             "authentik.tasks.middleware.MetricsMiddleware", |  | ||||||
|             { |  | ||||||
|                 "multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"), |  | ||||||
|                 "prefix": "authentik", |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|     ), |  | ||||||
|     "test": TEST, |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| # Sentry integration | # Sentry integration | ||||||
|  |  | ||||||
| env = get_env() | env = get_env() | ||||||
| _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) | _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) | ||||||
| if _ERROR_REPORTING: | if _ERROR_REPORTING: | ||||||
| @ -486,6 +432,9 @@ else: | |||||||
|     MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] |     MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] | ||||||
|     MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] |     MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] | ||||||
|  |  | ||||||
|  | TEST = False | ||||||
|  | TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" | ||||||
|  |  | ||||||
| structlog_configure() | structlog_configure() | ||||||
| LOGGING = get_logger_config() | LOGGING = get_logger_config() | ||||||
|  |  | ||||||
| @ -496,6 +445,7 @@ _DISALLOWED_ITEMS = [ | |||||||
|     "INSTALLED_APPS", |     "INSTALLED_APPS", | ||||||
|     "MIDDLEWARE", |     "MIDDLEWARE", | ||||||
|     "AUTHENTICATION_BACKENDS", |     "AUTHENTICATION_BACKENDS", | ||||||
|  |     "CELERY", | ||||||
|     "SPECTACULAR_SETTINGS", |     "SPECTACULAR_SETTINGS", | ||||||
|     "REST_FRAMEWORK", |     "REST_FRAMEWORK", | ||||||
| ] | ] | ||||||
| @ -522,6 +472,7 @@ def _update_settings(app_path: str): | |||||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) |         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) |         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) | ||||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) |         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) | ||||||
|  |         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) | ||||||
|         for _attr in dir(settings_module): |         for _attr in dir(settings_module): | ||||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: |             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||||
|                 globals()[_attr] = getattr(settings_module, _attr) |                 globals()[_attr] = getattr(settings_module, _attr) | ||||||
| @ -530,6 +481,7 @@ def _update_settings(app_path: str): | |||||||
|  |  | ||||||
|  |  | ||||||
| if DEBUG: | if DEBUG: | ||||||
|  |     CELERY["task_always_eager"] = True | ||||||
|     REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( |     REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( | ||||||
|         "rest_framework.renderers.BrowsableAPIRenderer" |         "rest_framework.renderers.BrowsableAPIRenderer" | ||||||
|     ) |     ) | ||||||
| @ -549,6 +501,10 @@ try: | |||||||
| except ImportError: | except ImportError: | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  | # Import events after other apps since it relies on tasks and other things from all apps | ||||||
|  | # being imported for @prefill_task | ||||||
|  | TENANT_APPS.append("authentik.events") | ||||||
|  |  | ||||||
|  |  | ||||||
| # Load subapps's settings | # Load subapps's settings | ||||||
| for _app in set(SHARED_APPS + TENANT_APPS): | for _app in set(SHARED_APPS + TENANT_APPS): | ||||||
|  | |||||||
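Adding `CELERY` to `_DISALLOWED_ITEMS` stops subapps from replacing the dict wholesale; instead, `_update_settings()` folds each app's optional `CELERY_BEAT_SCHEDULE` into `CELERY["beat_schedule"]`, which is how the SCIM schedule from the first hunk of this diff gets registered. A simplified sketch of that merge, assuming each app may ship an optional `settings` module:

```python
# Simplified sketch of the beat-schedule merge performed by _update_settings().
# Assumption: any app listed in SHARED_APPS/TENANT_APPS may provide a settings
# module that defines an optional CELERY_BEAT_SCHEDULE dict.
from importlib import import_module


def collect_beat_schedules(app_paths: list[str]) -> dict:
    schedule: dict = {}
    for app_path in app_paths:
        try:
            settings_module = import_module(f"{app_path}.settings")
        except ImportError:
            continue
        schedule.update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
    return schedule
```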
Some files were not shown because too many files have changed in this diff.