Compare commits
	
		
			249 Commits
		
	
	
		
			docusaurus
			...
			celery-2-d
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 80eb56b016 | |||
| 3a1d0fbd35 | |||
| c2dc38e804 | |||
| 9aad7e29a4 | |||
| 315f9073cc | |||
| c94fa13826 | |||
| 6b3fbb0abf | |||
| 06191d6cfc | |||
| c608dc5110 | |||
| bb05d6063d | |||
| c5b4a630c9 | |||
| 4a73c65710 | |||
| 51e645c26a | |||
| de3607b9a6 | |||
| fb5804b914 | |||
| 287f41fed8 | |||
| 172b595a9f | |||
| d61995d0e2 | |||
| 056460dac0 | |||
| b30b736a71 | |||
| d3aa43ced0 | |||
| 92c837f7b5 | |||
| 7fcd65e318 | |||
| 9c01c7d890 | |||
| bdb0564d4c | |||
| 09688f7a55 | |||
| 31bb995490 | |||
| 5d813438f0 | |||
| d92764789f | |||
| c702fa1f95 | |||
| 786bada7d0 | |||
| 690766d377 | |||
| 4990abdf4a | |||
| 76616cf1c5 | |||
| 48c0b5449e | |||
| 8956606564 | |||
| 0908d0b559 | |||
| 7e5c90fcdc | |||
| 4e210b8299 | |||
| bf65ca4c70 | |||
| 670a88659e | |||
| 0ebbaeea6f | |||
| 49a9911271 | |||
| f3c0ca1a59 | |||
| 64e7bff16c | |||
| f264989f9e | |||
| 8f4353181e | |||
| 813b7aa8ba | |||
| 3bb3e0d1ef | |||
| 101e5adeba | |||
| ae228c91e3 | |||
| 411c52491e | |||
| 85869806a2 | |||
| d132db475e | |||
| 13b5aa604b | |||
| 97a5acdff5 | |||
| ea38f2d120 | |||
| 94867aaebf | |||
| 0e67c1d818 | |||
| 2a460201bb | |||
| f99cb3e9fb | |||
| e4bd05f444 | |||
| 80c4eb9bef | |||
| 96b4d5aee4 | |||
| 6321537c8d | |||
| 43975ec231 | |||
| 9b13922fc2 | |||
| 031456629b | |||
| 2433ed1c9b | |||
| 9868d54320 | |||
| 747a3ed6e9 | |||
| 527e849ce2 | |||
| cfcd54ca19 | |||
| faed9cd66e | |||
| 897d0dbcbd | |||
| a12e991798 | |||
| e5b86c3578 | |||
| 07ff433134 | |||
| 21b3e0c8cb | |||
| cbdec236dd | |||
| 2509ccde1c | |||
| 7e7b33dba7 | |||
| 13e1e44626 | |||
| e634f23fc8 | |||
| 8554a8e0c5 | |||
| b80abffafc | |||
| 204f21699e | |||
| 0fd478fa3e | |||
| 7d7e47e972 | |||
| 92a33a408f | |||
| d18a54e9e6 | |||
| e6614a0705 | |||
| 4c491cf221 | |||
| 17434c84cf | |||
| 234fb2a0c6 | |||
| 00612f921d | |||
| 8b67015190 | |||
| 5a5176e21f | |||
| 8980282a02 | |||
| 2ca9edb1bc | |||
| 61d970cda4 | |||
| 16fd9cab67 | |||
| 8c7818a252 | |||
| 374779102a | |||
| 0ac854458a | |||
| 1cfaddf49d | |||
| 5ae69f5987 | |||
| c62e1d5457 | |||
| 5b8681b1af | |||
| e0dcade9ad | |||
| 1a6ab7f24b | |||
| 769844314c | |||
| e211604860 | |||
| 7ed711e8f0 | |||
| 196b276345 | |||
| 3c62c80ff1 | |||
| a031f1107a | |||
| 8f399bba3f | |||
| e354e877ea | |||
| f254b8cf8c | |||
| 814b06322a | |||
| 217063ef7b | |||
| c2f7883a5c | |||
| bd64c34787 | |||
| 7518d4391f | |||
| e67bd79c66 | |||
| 2fc6da53c1 | |||
| 250a98cf98 | |||
| f2926fa1eb | |||
| 5e2af4a740 | |||
| 41f2ca42cc | |||
| 7ef547b357 | |||
| 1a9c529e92 | |||
| 75d19bfe76 | |||
| 7f8f7376e0 | |||
| 7c49de9cba | |||
| 00ac9b6367 | |||
| 0e786f7040 | |||
| 03d363ba84 | |||
| 3f33519ec0 | |||
| cae03beb6d | |||
| e4c1e5aed0 | |||
| 5acdd67cba | |||
| 40dbac7a65 | |||
| 1b4ed02959 | |||
| a95e730cdb | |||
| d8c13159e1 | |||
| 5f951ca3ef | |||
| 338da72622 | |||
| 90debcdd70 | |||
| 3766ca86e8 | |||
| 59c8472628 | |||
| 293616e6b0 | |||
| f7305d58b1 | |||
| ba94f63705 | |||
| 06b2e0d14b | |||
| 80a5f44491 | |||
| aca0bde46d | |||
| e671811ad2 | |||
| 3140325493 | |||
| 6c0b879b30 | |||
| 0e0fb37dd7 | |||
| d2cacdc640 | |||
| e65fabf040 | |||
| 107b96e65c | |||
| 5d7ba51872 | |||
| 3037701a14 | |||
| 66f8377c79 | |||
| 86f81d92aa | |||
| 369437f2a1 | |||
| 4b8b80f1d4 | |||
| f839aef33a | |||
| eb87e30076 | |||
| 4302f91028 | |||
| b0af20b0d5 | |||
| 9b556cf4c4 | |||
| 7118219544 | |||
| 475600ea87 | |||
| 2139e0be05 | |||
| a43a0f77fb | |||
| 8a073e8c60 | |||
| 35640fcdfa | |||
| c62f73400a | |||
| c92cbd7e22 | |||
| c3b0d09e04 | |||
| c5a40fced3 | |||
| 9cc6ebabc1 | |||
| e89659fe71 | |||
| 3c1512028d | |||
| c7f80686de | |||
| 8a8386cfcb | |||
| e60165ee45 | |||
| bc6085adc7 | |||
| d413e2875c | |||
| 144986f48e | |||
| c9f1e34beb | |||
| 39f769b150 | |||
| 78180e376f | |||
| d51150102c | |||
| d5da16ad26 | |||
| 2b12e32fcf | |||
| a9b9661155 | |||
| f5f0cef275 | |||
| b756965511 | |||
| db900c4a42 | |||
| 72cb62085b | |||
| 5a42815850 | |||
| b9083a906a | |||
| 04be734c49 | |||
| 1ed6cf7517 | |||
| d6c4f97158 | |||
| 781704fa38 | |||
| 28f4d7d566 | |||
| 991778b2be | |||
| 9465dafd7d | |||
| 75c13a8801 | |||
| 8ae0f145f5 | |||
| 4d0e0e3afe | |||
| 7aeb874ded | |||
| ffc695f7b8 | |||
| 93cb621af3 | |||
| 3a34680196 | |||
| 2335a3130a | |||
| 0bc4b69f52 | |||
| 43c5c1276d | |||
| a3ebfd9bbd | |||
| af5b894e62 | |||
| c982066235 | |||
| 1f6c1522b6 | |||
| bae83ba78e | |||
| 0d0aeab4ee | |||
| 7fe91339ad | |||
| 44dea1d208 | |||
| 6d3be40022 | |||
| 07773f92a0 | |||
| dbc4a2b730 | |||
| df15a78aac | |||
| 61b517edfa | |||
| 082b342f65 | |||
| 9a536ee4b9 | |||
| 677f04cab2 | |||
| 3ddc35cddc | |||
| ae211226ef | |||
| 6662611347 | |||
| c4b988c632 | |||
| 2b1ee8cd5c | |||
| e8cfc2b91e | |||
| de54404ab7 | |||
| f8c3b64274 | 
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -100,9 +100,6 @@ ipython_config.py | |||||||
| # pyenv | # pyenv | ||||||
| .python-version | .python-version | ||||||
|  |  | ||||||
| # celery beat schedule file |  | ||||||
| celerybeat-schedule |  | ||||||
|  |  | ||||||
| # SageMath parsed files | # SageMath parsed files | ||||||
| *.sage.py | *.sage.py | ||||||
|  |  | ||||||
| @ -166,8 +163,6 @@ dmypy.json | |||||||
|  |  | ||||||
| # pyenv | # pyenv | ||||||
|  |  | ||||||
| # celery beat schedule file |  | ||||||
|  |  | ||||||
| # SageMath parsed files | # SageMath parsed files | ||||||
|  |  | ||||||
| # Environments | # Environments | ||||||
|  | |||||||
| @ -122,6 +122,7 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec" | |||||||
|  |  | ||||||
| RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \ | RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \ | ||||||
|     --mount=type=bind,target=uv.lock,src=uv.lock \ |     --mount=type=bind,target=uv.lock,src=uv.lock \ | ||||||
|  |     --mount=type=bind,target=packages,src=packages \ | ||||||
|     --mount=type=cache,target=/root/.cache/uv \ |     --mount=type=cache,target=/root/.cache/uv \ | ||||||
|     uv sync --frozen --no-install-project --no-dev |     uv sync --frozen --no-install-project --no-dev | ||||||
|  |  | ||||||
| @ -167,6 +168,7 @@ COPY ./blueprints /blueprints | |||||||
| COPY ./lifecycle/ /lifecycle | COPY ./lifecycle/ /lifecycle | ||||||
| COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf | COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf | ||||||
| COPY --from=go-builder /go/authentik /bin/authentik | COPY --from=go-builder /go/authentik /bin/authentik | ||||||
|  | COPY ./packages/ /ak-root/packages | ||||||
| COPY --from=python-deps /ak-root/.venv /ak-root/.venv | COPY --from=python-deps /ak-root/.venv /ak-root/.venv | ||||||
| COPY --from=node-builder /work/web/dist/ /web/dist/ | COPY --from=node-builder /work/web/dist/ /web/dist/ | ||||||
| COPY --from=node-builder /work/web/authentik/ /web/authentik/ | COPY --from=node-builder /work/web/authentik/ /web/authentik/ | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								Makefile
									
									
									
									
									
								
							| @ -6,7 +6,7 @@ PWD = $(shell pwd) | |||||||
| UID = $(shell id -u) | UID = $(shell id -u) | ||||||
| GID = $(shell id -g) | GID = $(shell id -g) | ||||||
| NPM_VERSION = $(shell python -m scripts.generate_semver) | NPM_VERSION = $(shell python -m scripts.generate_semver) | ||||||
| PY_SOURCES = authentik tests scripts lifecycle .github | PY_SOURCES = authentik packages tests scripts lifecycle .github | ||||||
| DOCKER_IMAGE ?= "authentik:test" | DOCKER_IMAGE ?= "authentik:test" | ||||||
|  |  | ||||||
| GEN_API_TS = gen-ts-api | GEN_API_TS = gen-ts-api | ||||||
|  | |||||||
| @ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer): | |||||||
|             return __version__ |             return __version__ | ||||||
|         version_in_cache = cache.get(VERSION_CACHE_KEY) |         version_in_cache = cache.get(VERSION_CACHE_KEY) | ||||||
|         if not version_in_cache:  # pragma: no cover |         if not version_in_cache:  # pragma: no cover | ||||||
|             update_latest_version.delay() |             update_latest_version.send() | ||||||
|             return __version__ |             return __version__ | ||||||
|         return version_in_cache |         return version_in_cache | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,57 +0,0 @@ | |||||||
| """authentik administration overview""" |  | ||||||
|  |  | ||||||
| from socket import gethostname |  | ||||||
|  |  | ||||||
| from django.conf import settings |  | ||||||
| from drf_spectacular.utils import extend_schema, inline_serializer |  | ||||||
| from packaging.version import parse |  | ||||||
| from rest_framework.fields import BooleanField, CharField |  | ||||||
| from rest_framework.request import Request |  | ||||||
| from rest_framework.response import Response |  | ||||||
| from rest_framework.views import APIView |  | ||||||
|  |  | ||||||
| from authentik import get_full_version |  | ||||||
| from authentik.rbac.permissions import HasPermission |  | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class WorkerView(APIView): |  | ||||||
|     """Get currently connected worker count.""" |  | ||||||
|  |  | ||||||
|     permission_classes = [HasPermission("authentik_rbac.view_system_info")] |  | ||||||
|  |  | ||||||
|     @extend_schema( |  | ||||||
|         responses=inline_serializer( |  | ||||||
|             "Worker", |  | ||||||
|             fields={ |  | ||||||
|                 "worker_id": CharField(), |  | ||||||
|                 "version": CharField(), |  | ||||||
|                 "version_matching": BooleanField(), |  | ||||||
|             }, |  | ||||||
|             many=True, |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     def get(self, request: Request) -> Response: |  | ||||||
|         """Get currently connected worker count.""" |  | ||||||
|         raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5) |  | ||||||
|         our_version = parse(get_full_version()) |  | ||||||
|         response = [] |  | ||||||
|         for worker in raw: |  | ||||||
|             key = list(worker.keys())[0] |  | ||||||
|             version = worker[key].get("version") |  | ||||||
|             version_matching = False |  | ||||||
|             if version: |  | ||||||
|                 version_matching = parse(version) == our_version |  | ||||||
|             response.append( |  | ||||||
|                 {"worker_id": key, "version": version, "version_matching": version_matching} |  | ||||||
|             ) |  | ||||||
|         # In debug we run with `task_always_eager`, so tasks are ran on the main process |  | ||||||
|         if settings.DEBUG:  # pragma: no cover |  | ||||||
|             response.append( |  | ||||||
|                 { |  | ||||||
|                     "worker_id": f"authentik-debug@{gethostname()}", |  | ||||||
|                     "version": get_full_version(), |  | ||||||
|                     "version_matching": True, |  | ||||||
|                 } |  | ||||||
|             ) |  | ||||||
|         return Response(response) |  | ||||||
| @ -3,6 +3,9 @@ | |||||||
| from prometheus_client import Info | from prometheus_client import Info | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
| PROM_INFO = Info("authentik_version", "Currently running authentik version") | PROM_INFO = Info("authentik_version", "Currently running authentik version") | ||||||
|  |  | ||||||
| @ -30,3 +33,15 @@ class AuthentikAdminConfig(ManagedAppConfig): | |||||||
|             notification_version = notification.event.context["new_version"] |             notification_version = notification.event.context["new_version"] | ||||||
|             if LOCAL_VERSION >= parse(notification_version): |             if LOCAL_VERSION >= parse(notification_version): | ||||||
|                 notification.delete() |                 notification.delete() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def global_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.admin.tasks import update_latest_version | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=update_latest_version, | ||||||
|  |                 crontab=f"{fqdn_rand('admin_latest_version')} * * * *", | ||||||
|  |                 paused=CONFIG.get_bool("disable_update_check"), | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,15 +0,0 @@ | |||||||
| """authentik admin settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
| from django_tenants.utils import get_public_schema_name |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "admin_latest_version": { |  | ||||||
|         "task": "authentik.admin.tasks.update_latest_version", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"), |  | ||||||
|         "tenant_schemas": [get_public_schema_name()], |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,35 +0,0 @@ | |||||||
| """admin signals""" |  | ||||||
|  |  | ||||||
| from django.dispatch import receiver |  | ||||||
| from packaging.version import parse |  | ||||||
| from prometheus_client import Gauge |  | ||||||
|  |  | ||||||
| from authentik import get_full_version |  | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
| from authentik.root.monitoring import monitoring_set |  | ||||||
|  |  | ||||||
| GAUGE_WORKERS = Gauge( |  | ||||||
|     "authentik_admin_workers", |  | ||||||
|     "Currently connected workers, their versions and if they are the same version as authentik", |  | ||||||
|     ["version", "version_matched"], |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _version = parse(get_full_version()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(monitoring_set) |  | ||||||
| def monitoring_set_workers(sender, **kwargs): |  | ||||||
|     """Set worker gauge""" |  | ||||||
|     raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5) |  | ||||||
|     worker_version_count = {} |  | ||||||
|     for worker in raw: |  | ||||||
|         key = list(worker.keys())[0] |  | ||||||
|         version = worker[key].get("version") |  | ||||||
|         version_matching = False |  | ||||||
|         if version: |  | ||||||
|             version_matching = parse(version) == _version |  | ||||||
|         worker_version_count.setdefault(version, {"count": 0, "matching": version_matching}) |  | ||||||
|         worker_version_count[version]["count"] += 1 |  | ||||||
|     for version, stats in worker_version_count.items(): |  | ||||||
|         GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"]) |  | ||||||
| @ -2,6 +2,8 @@ | |||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq import actor | ||||||
| from packaging.version import parse | from packaging.version import parse | ||||||
| from requests import RequestException | from requests import RequestException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -9,10 +11,9 @@ from structlog.stdlib import get_logger | |||||||
| from authentik import __version__, get_build_hash | from authentik import __version__, get_build_hash | ||||||
| from authentik.admin.apps import PROM_INFO | from authentik.admin.apps import PROM_INFO | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| VERSION_NULL = "0.0.0" | VERSION_NULL = "0.0.0" | ||||||
| @ -32,13 +33,12 @@ def _set_prom_info(): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Update latest version info.")) | ||||||
| @prefill_task | def update_latest_version(): | ||||||
| def update_latest_version(self: SystemTask): |     self: Task = CurrentTask.get_task() | ||||||
|     """Update latest version info""" |  | ||||||
|     if CONFIG.get_bool("disable_update_check"): |     if CONFIG.get_bool("disable_update_check"): | ||||||
|         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) |         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) | ||||||
|         self.set_status(TaskStatus.WARNING, "Version check disabled.") |         self.info("Version check disabled.") | ||||||
|         return |         return | ||||||
|     try: |     try: | ||||||
|         response = get_http_session().get( |         response = get_http_session().get( | ||||||
| @ -48,7 +48,7 @@ def update_latest_version(self: SystemTask): | |||||||
|         data = response.json() |         data = response.json() | ||||||
|         upstream_version = data.get("stable", {}).get("version") |         upstream_version = data.get("stable", {}).get("version") | ||||||
|         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) |         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT) | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version") |         self.info("Successfully updated latest Version") | ||||||
|         _set_prom_info() |         _set_prom_info() | ||||||
|         # Check if upstream version is newer than what we're running, |         # Check if upstream version is newer than what we're running, | ||||||
|         # and if no event exists yet, create one. |         # and if no event exists yet, create one. | ||||||
| @ -71,7 +71,7 @@ def update_latest_version(self: SystemTask): | |||||||
|             ).save() |             ).save() | ||||||
|     except (RequestException, IndexError) as exc: |     except (RequestException, IndexError) as exc: | ||||||
|         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) |         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT) | ||||||
|         self.set_error(exc) |         raise exc | ||||||
|  |  | ||||||
|  |  | ||||||
| _set_prom_info() | _set_prom_info() | ||||||
|  | |||||||
| @ -29,13 +29,6 @@ class TestAdminAPI(TestCase): | |||||||
|         body = loads(response.content) |         body = loads(response.content) | ||||||
|         self.assertEqual(body["version_current"], __version__) |         self.assertEqual(body["version_current"], __version__) | ||||||
|  |  | ||||||
|     def test_workers(self): |  | ||||||
|         """Test Workers API""" |  | ||||||
|         response = self.client.get(reverse("authentik_api:admin_workers")) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         body = loads(response.content) |  | ||||||
|         self.assertEqual(len(body), 0) |  | ||||||
|  |  | ||||||
|     def test_apps(self): |     def test_apps(self): | ||||||
|         """Test apps API""" |         """Test apps API""" | ||||||
|         response = self.client.get(reverse("authentik_api:apps-list")) |         response = self.client.get(reverse("authentik_api:apps-list")) | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ class TestAdminTasks(TestCase): | |||||||
|         """Test Update checker with valid response""" |         """Test Update checker with valid response""" | ||||||
|         with Mocker() as mocker, CONFIG.patch("disable_update_check", False): |         with Mocker() as mocker, CONFIG.patch("disable_update_check", False): | ||||||
|             mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID) |             mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID) | ||||||
|             update_latest_version.delay().get() |             update_latest_version.send() | ||||||
|             self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") |             self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999") | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 Event.objects.filter( |                 Event.objects.filter( | ||||||
| @ -40,7 +40,7 @@ class TestAdminTasks(TestCase): | |||||||
|                 ).exists() |                 ).exists() | ||||||
|             ) |             ) | ||||||
|             # test that a consecutive check doesn't create a duplicate event |             # test that a consecutive check doesn't create a duplicate event | ||||||
|             update_latest_version.delay().get() |             update_latest_version.send() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 len( |                 len( | ||||||
|                     Event.objects.filter( |                     Event.objects.filter( | ||||||
| @ -56,7 +56,7 @@ class TestAdminTasks(TestCase): | |||||||
|         """Test Update checker with invalid response""" |         """Test Update checker with invalid response""" | ||||||
|         with Mocker() as mocker: |         with Mocker() as mocker: | ||||||
|             mocker.get("https://version.goauthentik.io/version.json", status_code=400) |             mocker.get("https://version.goauthentik.io/version.json", status_code=400) | ||||||
|             update_latest_version.delay().get() |             update_latest_version.send() | ||||||
|             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") |             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") | ||||||
|             self.assertFalse( |             self.assertFalse( | ||||||
|                 Event.objects.filter( |                 Event.objects.filter( | ||||||
| @ -67,14 +67,15 @@ class TestAdminTasks(TestCase): | |||||||
|     def test_version_disabled(self): |     def test_version_disabled(self): | ||||||
|         """Test Update checker while its disabled""" |         """Test Update checker while its disabled""" | ||||||
|         with CONFIG.patch("disable_update_check", True): |         with CONFIG.patch("disable_update_check", True): | ||||||
|             update_latest_version.delay().get() |             update_latest_version.send() | ||||||
|             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") |             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0") | ||||||
|  |  | ||||||
|     def test_clear_update_notifications(self): |     def test_clear_update_notifications(self): | ||||||
|         """Test clear of previous notification""" |         """Test clear of previous notification""" | ||||||
|         admin_config = apps.get_app_config("authentik_admin") |         admin_config = apps.get_app_config("authentik_admin") | ||||||
|         Event.objects.create( |         Event.objects.create( | ||||||
|             action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"} |             action=EventAction.UPDATE_AVAILABLE, | ||||||
|  |             context={"new_version": "99999999.9999999.9999999"}, | ||||||
|         ) |         ) | ||||||
|         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) |         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"}) | ||||||
|         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={}) |         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={}) | ||||||
|  | |||||||
| @ -6,13 +6,11 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet | |||||||
| from authentik.admin.api.system import SystemView | from authentik.admin.api.system import SystemView | ||||||
| from authentik.admin.api.version import VersionView | from authentik.admin.api.version import VersionView | ||||||
| from authentik.admin.api.version_history import VersionHistoryViewSet | from authentik.admin.api.version_history import VersionHistoryViewSet | ||||||
| from authentik.admin.api.workers import WorkerView |  | ||||||
|  |  | ||||||
| api_urlpatterns = [ | api_urlpatterns = [ | ||||||
|     ("admin/apps", AppsViewSet, "apps"), |     ("admin/apps", AppsViewSet, "apps"), | ||||||
|     ("admin/models", ModelViewSet, "models"), |     ("admin/models", ModelViewSet, "models"), | ||||||
|     path("admin/version/", VersionView.as_view(), name="admin_version"), |     path("admin/version/", VersionView.as_view(), name="admin_version"), | ||||||
|     ("admin/version/history", VersionHistoryViewSet, "version_history"), |     ("admin/version/history", VersionHistoryViewSet, "version_history"), | ||||||
|     path("admin/workers/", WorkerView.as_view(), name="admin_workers"), |  | ||||||
|     path("admin/system/", SystemView.as_view(), name="admin_system"), |     path("admin/system/", SystemView.as_view(), name="admin_system"), | ||||||
| ] | ] | ||||||
|  | |||||||
| @ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer): | |||||||
|         """Ensure the path (if set) specified is retrievable""" |         """Ensure the path (if set) specified is retrievable""" | ||||||
|         if path == "" or path.startswith(OCI_PREFIX): |         if path == "" or path.startswith(OCI_PREFIX): | ||||||
|             return path |             return path | ||||||
|         files: list[dict] = blueprints_find_dict.delay().get() |         files: list[dict] = blueprints_find_dict.send().get_result(block=True) | ||||||
|         if path not in [file["path"] for file in files]: |         if path not in [file["path"] for file in files]: | ||||||
|             raise ValidationError(_("Blueprint file does not exist")) |             raise ValidationError(_("Blueprint file does not exist")) | ||||||
|         return path |         return path | ||||||
| @ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): | |||||||
|     @action(detail=False, pagination_class=None, filter_backends=[]) |     @action(detail=False, pagination_class=None, filter_backends=[]) | ||||||
|     def available(self, request: Request) -> Response: |     def available(self, request: Request) -> Response: | ||||||
|         """Get blueprints""" |         """Get blueprints""" | ||||||
|         files: list[dict] = blueprints_find_dict.delay().get() |         files: list[dict] = blueprints_find_dict.send().get_result(block=True) | ||||||
|         return Response(files) |         return Response(files) | ||||||
|  |  | ||||||
|     @permission_required("authentik_blueprints.view_blueprintinstance") |     @permission_required("authentik_blueprints.view_blueprintinstance") | ||||||
| @ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet): | |||||||
|     def apply(self, request: Request, *args, **kwargs) -> Response: |     def apply(self, request: Request, *args, **kwargs) -> Response: | ||||||
|         """Apply a blueprint""" |         """Apply a blueprint""" | ||||||
|         blueprint = self.get_object() |         blueprint = self.get_object() | ||||||
|         apply_blueprint.delay(str(blueprint.pk)).get() |         apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint) | ||||||
|         return self.retrieve(request, *args, **kwargs) |         return self.retrieve(request, *args, **kwargs) | ||||||
|  | |||||||
| @ -6,9 +6,12 @@ from inspect import ismethod | |||||||
|  |  | ||||||
| from django.apps import AppConfig | from django.apps import AppConfig | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
|  | from dramatiq.broker import get_broker | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
| from authentik.root.signals import startup | from authentik.root.signals import startup | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedAppConfig(AppConfig): | class ManagedAppConfig(AppConfig): | ||||||
| @ -34,7 +37,7 @@ class ManagedAppConfig(AppConfig): | |||||||
|  |  | ||||||
|     def import_related(self): |     def import_related(self): | ||||||
|         """Automatically import related modules which rely on just being imported |         """Automatically import related modules which rely on just being imported | ||||||
|         to register themselves (mainly django signals and celery tasks)""" |         to register themselves (mainly django signals and tasks)""" | ||||||
|  |  | ||||||
|         def import_relative(rel_module: str): |         def import_relative(rel_module: str): | ||||||
|             try: |             try: | ||||||
| @ -80,6 +83,16 @@ class ManagedAppConfig(AppConfig): | |||||||
|         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY |         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY | ||||||
|         return func |         return func | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         """Get a list of schedule specs that must exist in each tenant""" | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def global_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         """Get a list of schedule specs that must exist in the default tenant""" | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|     def _reconcile_tenant(self) -> None: |     def _reconcile_tenant(self) -> None: | ||||||
|         """reconcile ourselves for tenanted methods""" |         """reconcile ourselves for tenanted methods""" | ||||||
|         from authentik.tenants.models import Tenant |         from authentik.tenants.models import Tenant | ||||||
| @ -100,8 +113,12 @@ class ManagedAppConfig(AppConfig): | |||||||
|         """ |         """ | ||||||
|         from django_tenants.utils import get_public_schema_name, schema_context |         from django_tenants.utils import get_public_schema_name, schema_context | ||||||
|  |  | ||||||
|  |         try: | ||||||
|             with schema_context(get_public_schema_name()): |             with schema_context(get_public_schema_name()): | ||||||
|                 self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) |                 self._reconcile(self.RECONCILE_GLOBAL_CATEGORY) | ||||||
|  |         except (DatabaseError, ProgrammingError, InternalError) as exc: | ||||||
|  |             self.logger.debug("Failed to access database to run reconcile", exc=exc) | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikBlueprintsConfig(ManagedAppConfig): | class AuthentikBlueprintsConfig(ManagedAppConfig): | ||||||
| @ -112,19 +129,29 @@ class AuthentikBlueprintsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Blueprints" |     verbose_name = "authentik Blueprints" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_global |  | ||||||
|     def load_blueprints_v1_tasks(self): |  | ||||||
|         """Load v1 tasks""" |  | ||||||
|         self.import_module("authentik.blueprints.v1.tasks") |  | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_tenant |  | ||||||
|     def blueprints_discovery(self): |  | ||||||
|         """Run blueprint discovery""" |  | ||||||
|         from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints |  | ||||||
|  |  | ||||||
|         blueprints_discovery.delay() |  | ||||||
|         clear_failed_blueprints.delay() |  | ||||||
|  |  | ||||||
|     def import_models(self): |     def import_models(self): | ||||||
|         super().import_models() |         super().import_models() | ||||||
|         self.import_module("authentik.blueprints.v1.meta.apply_blueprint") |         self.import_module("authentik.blueprints.v1.meta.apply_blueprint") | ||||||
|  |  | ||||||
|  |     @ManagedAppConfig.reconcile_global | ||||||
|  |     def tasks_middlewares(self): | ||||||
|  |         from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware | ||||||
|  |  | ||||||
|  |         get_broker().add_middleware(BlueprintWatcherMiddleware()) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=blueprints_discovery, | ||||||
|  |                 crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *", | ||||||
|  |                 send_on_startup=True, | ||||||
|  |             ), | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=clear_failed_blueprints, | ||||||
|  |                 crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *", | ||||||
|  |                 send_on_startup=True, | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
|  | from django.contrib.contenttypes.fields import GenericRelation | ||||||
| from django.contrib.postgres.fields import ArrayField | from django.contrib.postgres.fields import ArrayField | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| @ -71,6 +72,13 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): | |||||||
|     enabled = models.BooleanField(default=True) |     enabled = models.BooleanField(default=True) | ||||||
|     managed_models = ArrayField(models.TextField(), default=list) |     managed_models = ArrayField(models.TextField(), default=list) | ||||||
|  |  | ||||||
|  |     # Manual link to tasks instead of using TasksModel because of loop imports | ||||||
|  |     tasks = GenericRelation( | ||||||
|  |         "authentik_tasks.Task", | ||||||
|  |         content_type_field="rel_obj_content_type", | ||||||
|  |         object_id_field="rel_obj_id", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Blueprint Instance") |         verbose_name = _("Blueprint Instance") | ||||||
|         verbose_name_plural = _("Blueprint Instances") |         verbose_name_plural = _("Blueprint Instances") | ||||||
|  | |||||||
| @ -1,18 +0,0 @@ | |||||||
| """blueprint Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "blueprints_v1_discover": { |  | ||||||
|         "task": "authentik.blueprints.v1.tasks.blueprints_discovery", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
|     "blueprints_v1_cleanup": { |  | ||||||
|         "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
							
								
								
									
										2
									
								
								authentik/blueprints/tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								authentik/blueprints/tasks.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | |||||||
|  | # Import all v1 tasks for auto task discovery | ||||||
|  | from authentik.blueprints.v1.tasks import *  # noqa: F403 | ||||||
| @ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|             file.seek(0) |             file.seek(0) | ||||||
|             file_hash = sha512(file.read().encode()).hexdigest() |             file_hash = sha512(file.read().encode()).hexdigest() | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery() |             blueprints_discovery.send() | ||||||
|             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() |             instance = BlueprintInstance.objects.filter(name=blueprint_id).first() | ||||||
|             self.assertEqual(instance.last_applied_hash, file_hash) |             self.assertEqual(instance.last_applied_hash, file_hash) | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
| @ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery() |             blueprints_discovery.send() | ||||||
|             blueprint = BlueprintInstance.objects.filter(name="foo").first() |             blueprint = BlueprintInstance.objects.filter(name="foo").first() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 blueprint.last_applied_hash, |                 blueprint.last_applied_hash, | ||||||
| @ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase): | |||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|             file.flush() |             file.flush() | ||||||
|             blueprints_discovery() |             blueprints_discovery.send() | ||||||
|             blueprint.refresh_from_db() |             blueprint.refresh_from_db() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 blueprint.last_applied_hash, |                 blueprint.last_applied_hash, | ||||||
|  | |||||||
| @ -57,7 +57,6 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import ( | |||||||
|     EndpointDeviceConnection, |     EndpointDeviceConnection, | ||||||
| ) | ) | ||||||
| from authentik.events.logs import LogEvent, capture_logs | from authentik.events.logs import LogEvent, capture_logs | ||||||
| from authentik.events.models import SystemTask |  | ||||||
| from authentik.events.utils import cleanse_dict | from authentik.events.utils import cleanse_dict | ||||||
| from authentik.flows.models import FlowToken, Stage | from authentik.flows.models import FlowToken, Stage | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| @ -77,6 +76,7 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser | |||||||
| from authentik.rbac.models import Role | from authentik.rbac.models import Role | ||||||
| from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser | from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser | ||||||
| from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType | from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType | ||||||
|  | from authentik.tasks.models import Task | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| # Context set when the serializer is created in a blueprint context | # Context set when the serializer is created in a blueprint context | ||||||
| @ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]: | |||||||
|         SCIMProviderGroup, |         SCIMProviderGroup, | ||||||
|         SCIMProviderUser, |         SCIMProviderUser, | ||||||
|         Tenant, |         Tenant, | ||||||
|         SystemTask, |         Task, | ||||||
|         ConnectionToken, |         ConnectionToken, | ||||||
|         AuthorizationCode, |         AuthorizationCode, | ||||||
|         AccessToken, |         AccessToken, | ||||||
|  | |||||||
| @ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): | |||||||
|             return MetaResult() |             return MetaResult() | ||||||
|         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) |         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance) | ||||||
|  |  | ||||||
|         apply_blueprint(str(self.blueprint_instance.pk)) |         apply_blueprint(self.blueprint_instance.pk) | ||||||
|         return MetaResult() |         return MetaResult() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -4,12 +4,17 @@ from dataclasses import asdict, dataclass, field | |||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from sys import platform | from sys import platform | ||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
|  | from django.conf import settings | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError | from django.db import DatabaseError, InternalError, ProgrammingError | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound | ||||||
|  | from dramatiq.actor import actor | ||||||
|  | from dramatiq.middleware import Middleware | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from watchdog.events import ( | from watchdog.events import ( | ||||||
|     FileCreatedEvent, |     FileCreatedEvent, | ||||||
| @ -31,15 +36,13 @@ from authentik.blueprints.v1.importer import Importer | |||||||
| from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE | ||||||
| from authentik.blueprints.v1.oci import OCI_PREFIX | from authentik.blueprints.v1.oci import OCI_PREFIX | ||||||
| from authentik.events.logs import capture_logs | from authentik.events.logs import capture_logs | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task |  | ||||||
| from authentik.events.utils import sanitize_dict | from authentik.events.utils import sanitize_dict | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  | from authentik.tasks.schedules.models import Schedule | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| _file_watcher_started = False |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| @ -53,13 +56,9 @@ class BlueprintFile: | |||||||
|     meta: BlueprintMetadata | None = field(default=None) |     meta: BlueprintMetadata | None = field(default=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| def start_blueprint_watcher(): | class BlueprintWatcherMiddleware(Middleware): | ||||||
|     """Start blueprint watcher, if it's not running already.""" |     def start_blueprint_watcher(self): | ||||||
|     # This function might be called twice since it's called on celery startup |         """Start blueprint watcher""" | ||||||
|  |  | ||||||
|     global _file_watcher_started  # noqa: PLW0603 |  | ||||||
|     if _file_watcher_started: |  | ||||||
|         return |  | ||||||
|         observer = Observer() |         observer = Observer() | ||||||
|         kwargs = {} |         kwargs = {} | ||||||
|         if platform.startswith("linux"): |         if platform.startswith("linux"): | ||||||
| @ -68,7 +67,10 @@ def start_blueprint_watcher(): | |||||||
|             BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs |             BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs | ||||||
|         ) |         ) | ||||||
|         observer.start() |         observer.start() | ||||||
|     _file_watcher_started = True |  | ||||||
|  |     def after_worker_boot(self, broker, worker): | ||||||
|  |         if not settings.TEST: | ||||||
|  |             self.start_blueprint_watcher() | ||||||
|  |  | ||||||
|  |  | ||||||
| class BlueprintEventHandler(FileSystemEventHandler): | class BlueprintEventHandler(FileSystemEventHandler): | ||||||
| @ -92,7 +94,7 @@ class BlueprintEventHandler(FileSystemEventHandler): | |||||||
|         LOGGER.debug("new blueprint file created, starting discovery") |         LOGGER.debug("new blueprint file created, starting discovery") | ||||||
|         for tenant in Tenant.objects.filter(ready=True): |         for tenant in Tenant.objects.filter(ready=True): | ||||||
|             with tenant: |             with tenant: | ||||||
|                 blueprints_discovery.delay() |                 Schedule.dispatch_by_actor(blueprints_discovery) | ||||||
|  |  | ||||||
|     def on_modified(self, event: FileSystemEvent): |     def on_modified(self, event: FileSystemEvent): | ||||||
|         """Process file modification""" |         """Process file modification""" | ||||||
| @ -103,14 +105,14 @@ class BlueprintEventHandler(FileSystemEventHandler): | |||||||
|             with tenant: |             with tenant: | ||||||
|                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): |                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): | ||||||
|                     LOGGER.debug("modified blueprint file, starting apply", instance=instance) |                     LOGGER.debug("modified blueprint file, starting apply", instance=instance) | ||||||
|                     apply_blueprint.delay(instance.pk.hex) |                     apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor( | ||||||
|  |     description=_("Find blueprints as `blueprints_find` does, but return a safe dict."), | ||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), |     throws=(DatabaseError, ProgrammingError, InternalError), | ||||||
| ) | ) | ||||||
| def blueprints_find_dict(): | def blueprints_find_dict(): | ||||||
|     """Find blueprints as `blueprints_find` does, but return a safe dict""" |  | ||||||
|     blueprints = [] |     blueprints = [] | ||||||
|     for blueprint in blueprints_find(): |     for blueprint in blueprints_find(): | ||||||
|         blueprints.append(sanitize_dict(asdict(blueprint))) |         blueprints.append(sanitize_dict(asdict(blueprint))) | ||||||
| @ -146,21 +148,19 @@ def blueprints_find() -> list[BlueprintFile]: | |||||||
|     return blueprints |     return blueprints | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor( | ||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True |     description=_("Find blueprints and check if they need to be created in the database."), | ||||||
|  |     throws=(DatabaseError, ProgrammingError, InternalError), | ||||||
| ) | ) | ||||||
| @prefill_task | def blueprints_discovery(path: str | None = None): | ||||||
| def blueprints_discovery(self: SystemTask, path: str | None = None): |     self: Task = CurrentTask.get_task() | ||||||
|     """Find blueprints and check if they need to be created in the database""" |  | ||||||
|     count = 0 |     count = 0 | ||||||
|     for blueprint in blueprints_find(): |     for blueprint in blueprints_find(): | ||||||
|         if path and blueprint.path != path: |         if path and blueprint.path != path: | ||||||
|             continue |             continue | ||||||
|         check_blueprint_v1_file(blueprint) |         check_blueprint_v1_file(blueprint) | ||||||
|         count += 1 |         count += 1 | ||||||
|     self.set_status( |     self.info(f"Successfully imported {count} files.") | ||||||
|         TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count)) |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_blueprint_v1_file(blueprint: BlueprintFile): | def check_blueprint_v1_file(blueprint: BlueprintFile): | ||||||
| @ -187,22 +187,26 @@ def check_blueprint_v1_file(blueprint: BlueprintFile): | |||||||
|         ) |         ) | ||||||
|     if instance.last_applied_hash != blueprint.hash: |     if instance.last_applied_hash != blueprint.hash: | ||||||
|         LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path) |         LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path) | ||||||
|         apply_blueprint.delay(str(instance.pk)) |         apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Apply single blueprint.")) | ||||||
|     bind=True, | def apply_blueprint(instance_pk: UUID): | ||||||
|     base=SystemTask, |     try: | ||||||
| ) |         self: Task = CurrentTask.get_task() | ||||||
| def apply_blueprint(self: SystemTask, instance_pk: str): |     except CurrentTaskNotFound: | ||||||
|     """Apply single blueprint""" |         self = Task() | ||||||
|     self.save_on_success = False |     self.set_uid(str(instance_pk)) | ||||||
|     instance: BlueprintInstance | None = None |     instance: BlueprintInstance | None = None | ||||||
|     try: |     try: | ||||||
|         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() |         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() | ||||||
|         if not instance or not instance.enabled: |         if not instance: | ||||||
|  |             self.warning(f"Could not find blueprint {instance_pk}, skipping") | ||||||
|             return |             return | ||||||
|         self.set_uid(slugify(instance.name)) |         self.set_uid(slugify(instance.name)) | ||||||
|  |         if not instance.enabled: | ||||||
|  |             self.info(f"Blueprint {instance.name} is disabled, skipping") | ||||||
|  |             return | ||||||
|         blueprint_content = instance.retrieve() |         blueprint_content = instance.retrieve() | ||||||
|         file_hash = sha512(blueprint_content.encode()).hexdigest() |         file_hash = sha512(blueprint_content.encode()).hexdigest() | ||||||
|         importer = Importer.from_string(blueprint_content, instance.context) |         importer = Importer.from_string(blueprint_content, instance.context) | ||||||
| @ -212,19 +216,18 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | |||||||
|         if not valid: |         if not valid: | ||||||
|             instance.status = BlueprintInstanceStatus.ERROR |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
|             instance.save() |             instance.save() | ||||||
|             self.set_status(TaskStatus.ERROR, *logs) |             self.logs(logs) | ||||||
|             return |             return | ||||||
|         with capture_logs() as logs: |         with capture_logs() as logs: | ||||||
|             applied = importer.apply() |             applied = importer.apply() | ||||||
|             if not applied: |             if not applied: | ||||||
|                 instance.status = BlueprintInstanceStatus.ERROR |                 instance.status = BlueprintInstanceStatus.ERROR | ||||||
|                 instance.save() |                 instance.save() | ||||||
|                 self.set_status(TaskStatus.ERROR, *logs) |                 self.logs(logs) | ||||||
|                 return |                 return | ||||||
|         instance.status = BlueprintInstanceStatus.SUCCESSFUL |         instance.status = BlueprintInstanceStatus.SUCCESSFUL | ||||||
|         instance.last_applied_hash = file_hash |         instance.last_applied_hash = file_hash | ||||||
|         instance.last_applied = now() |         instance.last_applied = now() | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL) |  | ||||||
|     except ( |     except ( | ||||||
|         OSError, |         OSError, | ||||||
|         DatabaseError, |         DatabaseError, | ||||||
| @ -235,15 +238,14 @@ def apply_blueprint(self: SystemTask, instance_pk: str): | |||||||
|     ) as exc: |     ) as exc: | ||||||
|         if instance: |         if instance: | ||||||
|             instance.status = BlueprintInstanceStatus.ERROR |             instance.status = BlueprintInstanceStatus.ERROR | ||||||
|         self.set_error(exc) |         self.error(exc) | ||||||
|     finally: |     finally: | ||||||
|         if instance: |         if instance: | ||||||
|             instance.save() |             instance.save() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Remove blueprints which couldn't be fetched.")) | ||||||
| def clear_failed_blueprints(): | def clear_failed_blueprints(): | ||||||
|     """Remove blueprints which couldn't be fetched""" |  | ||||||
|     # Exclude OCI blueprints as those might be temporarily unavailable |     # Exclude OCI blueprints as those might be temporarily unavailable | ||||||
|     for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX): |     for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX): | ||||||
|         try: |         try: | ||||||
|  | |||||||
| @ -9,6 +9,7 @@ class AuthentikBrandsConfig(ManagedAppConfig): | |||||||
|     name = "authentik.brands" |     name = "authentik.brands" | ||||||
|     label = "authentik_brands" |     label = "authentik_brands" | ||||||
|     verbose_name = "authentik Brands" |     verbose_name = "authentik Brands" | ||||||
|  |     default = True | ||||||
|     mountpoints = { |     mountpoints = { | ||||||
|         "authentik.brands.urls_root": "", |         "authentik.brands.urls_root": "", | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -1,8 +1,7 @@ | |||||||
| """authentik core app config""" | """authentik core app config""" | ||||||
|  |  | ||||||
| from django.conf import settings |  | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikCoreConfig(ManagedAppConfig): | class AuthentikCoreConfig(ManagedAppConfig): | ||||||
| @ -14,14 +13,6 @@ class AuthentikCoreConfig(ManagedAppConfig): | |||||||
|     mountpoint = "" |     mountpoint = "" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_global |  | ||||||
|     def debug_worker_hook(self): |  | ||||||
|         """Dispatch startup tasks inline when debugging""" |  | ||||||
|         if settings.DEBUG: |  | ||||||
|             from authentik.root.celery import worker_ready_hook |  | ||||||
|  |  | ||||||
|             worker_ready_hook() |  | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_tenant |     @ManagedAppConfig.reconcile_tenant | ||||||
|     def source_inbuilt(self): |     def source_inbuilt(self): | ||||||
|         """Reconcile inbuilt source""" |         """Reconcile inbuilt source""" | ||||||
| @ -34,3 +25,18 @@ class AuthentikCoreConfig(ManagedAppConfig): | |||||||
|             }, |             }, | ||||||
|             managed=Source.MANAGED_INBUILT, |             managed=Source.MANAGED_INBUILT, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.core.tasks import clean_expired_models, clean_temporary_users | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=clean_expired_models, | ||||||
|  |                 crontab="2-59/5 * * * *", | ||||||
|  |             ), | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=clean_temporary_users, | ||||||
|  |                 crontab="9-59/5 * * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,21 +0,0 @@ | |||||||
| """Run bootstrap tasks""" |  | ||||||
|  |  | ||||||
| from django.core.management.base import BaseCommand |  | ||||||
| from django_tenants.utils import get_public_schema_name |  | ||||||
|  |  | ||||||
| from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant |  | ||||||
| from authentik.tenants.models import Tenant |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Command(BaseCommand): |  | ||||||
|     """Run bootstrap tasks to ensure certain objects are created""" |  | ||||||
|  |  | ||||||
|     def handle(self, **options): |  | ||||||
|         for task in _get_startup_tasks_default_tenant(): |  | ||||||
|             with Tenant.objects.get(schema_name=get_public_schema_name()): |  | ||||||
|                 task() |  | ||||||
|  |  | ||||||
|         for task in _get_startup_tasks_all_tenants(): |  | ||||||
|             for tenant in Tenant.objects.filter(ready=True): |  | ||||||
|                 with tenant: |  | ||||||
|                     task() |  | ||||||
| @ -1,47 +0,0 @@ | |||||||
| """Run worker""" |  | ||||||
|  |  | ||||||
| from sys import exit as sysexit |  | ||||||
| from tempfile import tempdir |  | ||||||
|  |  | ||||||
| from celery.apps.worker import Worker |  | ||||||
| from django.core.management.base import BaseCommand |  | ||||||
| from django.db import close_old_connections |  | ||||||
| from structlog.stdlib import get_logger |  | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
| from authentik.lib.debug import start_debug_server |  | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Command(BaseCommand): |  | ||||||
|     """Run worker""" |  | ||||||
|  |  | ||||||
|     def add_arguments(self, parser): |  | ||||||
|         parser.add_argument( |  | ||||||
|             "-b", |  | ||||||
|             "--beat", |  | ||||||
|             action="store_false", |  | ||||||
|             help="When set, this worker will _not_ run Beat (scheduled) tasks", |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def handle(self, **options): |  | ||||||
|         LOGGER.debug("Celery options", **options) |  | ||||||
|         close_old_connections() |  | ||||||
|         start_debug_server() |  | ||||||
|         worker: Worker = CELERY_APP.Worker( |  | ||||||
|             no_color=False, |  | ||||||
|             quiet=True, |  | ||||||
|             optimization="fair", |  | ||||||
|             autoscale=(CONFIG.get_int("worker.concurrency"), 1), |  | ||||||
|             task_events=True, |  | ||||||
|             beat=options.get("beat", True), |  | ||||||
|             schedule_filename=f"{tempdir}/celerybeat-schedule", |  | ||||||
|             queues=["authentik", "authentik_scheduled", "authentik_events"], |  | ||||||
|         ) |  | ||||||
|         for task in CELERY_APP.tasks: |  | ||||||
|             LOGGER.debug("Registered task", task=task) |  | ||||||
|  |  | ||||||
|         worker.start() |  | ||||||
|         sysexit(worker.exitcode) |  | ||||||
| @ -3,6 +3,9 @@ | |||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import actor | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -11,17 +14,14 @@ from authentik.core.models import ( | |||||||
|     ExpiringModel, |     ExpiringModel, | ||||||
|     User, |     User, | ||||||
| ) | ) | ||||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | from authentik.tasks.models import Task | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Remove expired objects.")) | ||||||
| @prefill_task | def clean_expired_models(): | ||||||
| def clean_expired_models(self: SystemTask): |     self: Task = CurrentTask.get_task() | ||||||
|     """Remove expired objects""" |  | ||||||
|     messages = [] |  | ||||||
|     for cls in ExpiringModel.__subclasses__(): |     for cls in ExpiringModel.__subclasses__(): | ||||||
|         cls: ExpiringModel |         cls: ExpiringModel | ||||||
|         objects = ( |         objects = ( | ||||||
| @ -31,16 +31,13 @@ def clean_expired_models(self: SystemTask): | |||||||
|         for obj in objects: |         for obj in objects: | ||||||
|             obj.expire_action() |             obj.expire_action() | ||||||
|         LOGGER.debug("Expired models", model=cls, amount=amount) |         LOGGER.debug("Expired models", model=cls, amount=amount) | ||||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") |         self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Remove temporary users created by SAML Sources.")) | ||||||
| @prefill_task | def clean_temporary_users(): | ||||||
| def clean_temporary_users(self: SystemTask): |     self: Task = CurrentTask.get_task() | ||||||
|     """Remove temporary users created by SAML Sources""" |  | ||||||
|     _now = datetime.now() |     _now = datetime.now() | ||||||
|     messages = [] |  | ||||||
|     deleted_users = 0 |     deleted_users = 0 | ||||||
|     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): |     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): | ||||||
|         if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): |         if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): | ||||||
| @ -52,5 +49,4 @@ def clean_temporary_users(self: SystemTask): | |||||||
|             LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) |             LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) | ||||||
|             user.delete() |             user.delete() | ||||||
|             deleted_users += 1 |             deleted_users += 1 | ||||||
|     messages.append(f"Successfully deleted {deleted_users} users.") |     self.info(f"Successfully deleted {deleted_users} users.") | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) |  | ||||||
|  | |||||||
| @ -36,7 +36,7 @@ class TestTasks(APITestCase): | |||||||
|             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API |             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API | ||||||
|         ) |         ) | ||||||
|         key = token.key |         key = token.key | ||||||
|         clean_expired_models.delay().get() |         clean_expired_models.send() | ||||||
|         token.refresh_from_db() |         token.refresh_from_db() | ||||||
|         self.assertNotEqual(key, token.key) |         self.assertNotEqual(key, token.key) | ||||||
|  |  | ||||||
| @ -50,5 +50,5 @@ class TestTasks(APITestCase): | |||||||
|                 USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), |                 USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         clean_temporary_users.delay().get() |         clean_temporary_users.send() | ||||||
|         self.assertFalse(User.objects.filter(username=username)) |         self.assertFalse(User.objects.filter(username=username)) | ||||||
|  | |||||||
| @ -4,6 +4,8 @@ from datetime import UTC, datetime | |||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
| MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" | MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" | ||||||
|  |  | ||||||
| @ -67,3 +69,14 @@ class AuthentikCryptoConfig(ManagedAppConfig): | |||||||
|                 "key_data": builder.private_key, |                 "key_data": builder.private_key, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.crypto.tasks import certificate_discovery | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=certificate_discovery, | ||||||
|  |                 crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """Crypto task Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "crypto_certificate_discovery": { |  | ||||||
|         "task": "authentik.crypto.tasks.certificate_discovery", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend | |||||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||||
| from cryptography.x509.base import load_pem_x509_certificate | from cryptography.x509.base import load_pem_x509_certificate | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import actor | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -36,10 +36,9 @@ def ensure_certificate_valid(body: str): | |||||||
|     return body |     return body | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Discover, import and update certificates from the filesystem.")) | ||||||
| @prefill_task | def certificate_discovery(): | ||||||
| def certificate_discovery(self: SystemTask): |     self: Task = CurrentTask.get_task() | ||||||
|     """Discover, import and update certificates from the filesystem""" |  | ||||||
|     certs = {} |     certs = {} | ||||||
|     private_keys = {} |     private_keys = {} | ||||||
|     discovered = 0 |     discovered = 0 | ||||||
| @ -84,6 +83,4 @@ def certificate_discovery(self: SystemTask): | |||||||
|                 dirty = True |                 dirty = True | ||||||
|         if dirty: |         if dirty: | ||||||
|             cert.save() |             cert.save() | ||||||
|     self.set_status( |     self.info(f"Successfully imported {discovered} files.") | ||||||
|         TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered)) |  | ||||||
|     ) |  | ||||||
|  | |||||||
| @ -338,7 +338,7 @@ class TestCrypto(APITestCase): | |||||||
|             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: |             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: | ||||||
|                 _key.write(builder.private_key) |                 _key.write(builder.private_key) | ||||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): |             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||||
|                 certificate_discovery() |                 certificate_discovery.send() | ||||||
|         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( |         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||||
|             managed=MANAGED_DISCOVERED % "foo" |             managed=MANAGED_DISCOVERED % "foo" | ||||||
|         ).first() |         ).first() | ||||||
|  | |||||||
| @ -3,6 +3,8 @@ | |||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnterpriseConfig(ManagedAppConfig): | class EnterpriseConfig(ManagedAppConfig): | ||||||
| @ -26,3 +28,14 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | |||||||
|         from authentik.enterprise.license import LicenseKey |         from authentik.enterprise.license import LicenseKey | ||||||
|  |  | ||||||
|         return LicenseKey.cached_summary().status.is_valid |         return LicenseKey.cached_summary().status.is_valid | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.enterprise.tasks import enterprise_update_usage | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=enterprise_update_usage, | ||||||
|  |                 crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| """authentik Unique Password policy app config""" | """authentik Unique Password policy app config""" | ||||||
|  |  | ||||||
| from authentik.enterprise.apps import EnterpriseConfig | from authentik.enterprise.apps import EnterpriseConfig | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | ||||||
| @ -8,3 +10,21 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | |||||||
|     label = "authentik_policies_unique_password" |     label = "authentik_policies_unique_password" | ||||||
|     verbose_name = "authentik Enterprise.Policies.Unique Password" |     verbose_name = "authentik Enterprise.Policies.Unique Password" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.enterprise.policies.unique_password.tasks import ( | ||||||
|  |             check_and_purge_password_history, | ||||||
|  |             trim_password_histories, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=trim_password_histories, | ||||||
|  |                 crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *", | ||||||
|  |             ), | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=check_and_purge_password_history, | ||||||
|  |                 crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,20 +0,0 @@ | |||||||
| """Unique Password Policy settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "policies_unique_password_trim_history": { |  | ||||||
|         "task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
|     "policies_unique_password_check_purge": { |  | ||||||
|         "task": ( |  | ||||||
|             "authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history" |  | ||||||
|         ), |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -1,35 +1,37 @@ | |||||||
| from django.db.models.aggregates import Count | from django.db.models.aggregates import Count | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import actor | ||||||
| from structlog import get_logger | from structlog import get_logger | ||||||
|  |  | ||||||
| from authentik.enterprise.policies.unique_password.models import ( | from authentik.enterprise.policies.unique_password.models import ( | ||||||
|     UniquePasswordPolicy, |     UniquePasswordPolicy, | ||||||
|     UserPasswordHistory, |     UserPasswordHistory, | ||||||
| ) | ) | ||||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | from authentik.tasks.models import Task | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor( | ||||||
| @prefill_task |     description=_( | ||||||
| def check_and_purge_password_history(self: SystemTask): |         "Check if any UniquePasswordPolicy exists, and if not, purge the password history table." | ||||||
|     """Check if any UniquePasswordPolicy exists, and if not, purge the password history table. |     ) | ||||||
|     This is run on a schedule instead of being triggered by policy binding deletion. | ) | ||||||
|     """ | def check_and_purge_password_history(): | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|  |  | ||||||
|     if not UniquePasswordPolicy.objects.exists(): |     if not UniquePasswordPolicy.objects.exists(): | ||||||
|         UserPasswordHistory.objects.all().delete() |         UserPasswordHistory.objects.all().delete() | ||||||
|         LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") |         LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory") |         self.info("Successfully purged UserPasswordHistory") | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     self.set_status( |     self.info("Not purging password histories, a unique password policy exists") | ||||||
|         TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists" |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Remove user password history that are too old.")) | ||||||
| def trim_password_histories(self: SystemTask): | def trim_password_histories(): | ||||||
|     """Removes rows from UserPasswordHistory older than |     """Removes rows from UserPasswordHistory older than | ||||||
|     the `n` most recent entries. |     the `n` most recent entries. | ||||||
|  |  | ||||||
| @ -37,6 +39,8 @@ def trim_password_histories(self: SystemTask): | |||||||
|     UniquePasswordPolicy policies. |     UniquePasswordPolicy policies. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|  |  | ||||||
|     # No policy, we'll let the cleanup above do its thing |     # No policy, we'll let the cleanup above do its thing | ||||||
|     if not UniquePasswordPolicy.objects.exists(): |     if not UniquePasswordPolicy.objects.exists(): | ||||||
|         return |         return | ||||||
| @ -63,4 +67,4 @@ def trim_password_histories(self: SystemTask): | |||||||
|  |  | ||||||
|     num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() |     num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() | ||||||
|     LOGGER.debug("Deleted stale password history records", count=num_deleted) |     LOGGER.debug("Deleted stale password history records", count=num_deleted) | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records") |     self.info(f"Delete {num_deleted} stale password history records") | ||||||
|  | |||||||
| @ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | |||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
|  |  | ||||||
|         # Run the task - should purge since no policy is in use |         # Run the task - should purge since no policy is in use | ||||||
|         check_and_purge_password_history() |         check_and_purge_password_history.send() | ||||||
|  |  | ||||||
|         # Verify the table is empty |         # Verify the table is empty | ||||||
|         self.assertFalse(UserPasswordHistory.objects.exists()) |         self.assertFalse(UserPasswordHistory.objects.exists()) | ||||||
| @ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | |||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
|  |  | ||||||
|         # Run the task - should NOT purge since a policy is in use |         # Run the task - should NOT purge since a policy is in use | ||||||
|         check_and_purge_password_history() |         check_and_purge_password_history.send() | ||||||
|  |  | ||||||
|         # Verify the entries still exist |         # Verify the entries still exist | ||||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) |         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||||
| @ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=True, |             enabled=True, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.delay() |         trim_password_histories.send() | ||||||
|         user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) |         user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) | ||||||
|         self.assertEqual(len(user_pwd_history_qs), 1) |         self.assertEqual(len(user_pwd_history_qs), 1) | ||||||
|  |  | ||||||
| @ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=False, |             enabled=False, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.delay() |         trim_password_histories.send() | ||||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) |         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||||
|  |  | ||||||
|     def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): |     def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): | ||||||
| @ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             enabled=True, |             enabled=True, | ||||||
|             order=0, |             order=0, | ||||||
|         ) |         ) | ||||||
|         trim_password_histories.delay() |         trim_password_histories.send() | ||||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) |         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||||
|  | |||||||
| @ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi | |||||||
|     ] |     ] | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     sync_single_task = google_workspace_sync |     sync_task = google_workspace_sync | ||||||
|     sync_objects_task = google_workspace_sync_objects |     sync_objects_task = google_workspace_sync_objects | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import Actor | ||||||
| from google.oauth2.service_account import Credentials | from google.oauth2.service_account import Credentials | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| @ -110,6 +111,12 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|         help_text=_("Property mappings used for group creation/updating."), |         help_text=_("Property mappings used for group creation/updating."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def sync_actor(self) -> Actor: | ||||||
|  |         from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync | ||||||
|  |  | ||||||
|  |         return google_workspace_sync | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, |         self, | ||||||
|         model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], |         model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """Google workspace provider task Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "providers_google_workspace_sync": { |  | ||||||
|         "task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -2,15 +2,13 @@ | |||||||
|  |  | ||||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||||
| from authentik.enterprise.providers.google_workspace.tasks import ( | from authentik.enterprise.providers.google_workspace.tasks import ( | ||||||
|     google_workspace_sync, |     google_workspace_sync_direct_dispatch, | ||||||
|     google_workspace_sync_direct, |     google_workspace_sync_m2m_dispatch, | ||||||
|     google_workspace_sync_m2m, |  | ||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     GoogleWorkspaceProvider, |     GoogleWorkspaceProvider, | ||||||
|     task_sync_single=google_workspace_sync, |     task_sync_direct_dispatch=google_workspace_sync_direct_dispatch, | ||||||
|     task_sync_direct=google_workspace_sync_direct, |     task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch, | ||||||
|     task_sync_m2m=google_workspace_sync_m2m, |  | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,37 +1,48 @@ | |||||||
| """Google Provider tasks""" | """Google Provider tasks""" | ||||||
|  |  | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import actor | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||||
| from authentik.events.system_tasks import SystemTask |  | ||||||
| from authentik.lib.sync.outgoing.exceptions import TransientSyncException |  | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(GoogleWorkspaceProvider) | sync_tasks = SyncTasks(GoogleWorkspaceProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor(description=_("Sync Google Workspace provider objects.")) | ||||||
| def google_workspace_sync_objects(*args, **kwargs): | def google_workspace_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Full sync for Google Workspace provider.")) | ||||||
|     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | def google_workspace_sync(provider_pk: int, *args, **kwargs): | ||||||
| ) |  | ||||||
| def google_workspace_sync(self, provider_pk: int, *args, **kwargs): |  | ||||||
|     """Run full sync for Google Workspace provider""" |     """Run full sync for Google Workspace provider""" | ||||||
|     return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects) |     return sync_tasks.sync(provider_pk, google_workspace_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Sync a direct object (user, group) for Google Workspace provider.")) | ||||||
| def google_workspace_sync_all(): |  | ||||||
|     return sync_tasks.sync_all(google_workspace_sync) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) |  | ||||||
| def google_workspace_sync_direct(*args, **kwargs): | def google_workspace_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor( | ||||||
|  |     description=_( | ||||||
|  |         "Dispatch syncs for a direct object (user, group) for Google Workspace providers." | ||||||
|  |     ) | ||||||
|  | ) | ||||||
|  | def google_workspace_sync_direct_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor(description=_("Sync a related object (memberships) for Google Workspace provider.")) | ||||||
| def google_workspace_sync_m2m(*args, **kwargs): | def google_workspace_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor( | ||||||
|  |     description=_( | ||||||
|  |         "Dispatch syncs for a related object (memberships) for Google Workspace providers." | ||||||
|  |     ) | ||||||
|  | ) | ||||||
|  | def google_workspace_sync_m2m_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs) | ||||||
|  | |||||||
| @ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase): | |||||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", |             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), |             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||||
|         ): |         ): | ||||||
|             google_workspace_sync.delay(self.provider.pk).get() |             google_workspace_sync.send(self.provider.pk).get_result() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 GoogleWorkspaceProviderGroup.objects.filter( |                 GoogleWorkspaceProviderGroup.objects.filter( | ||||||
|                     group=different_group, provider=self.provider |                     group=different_group, provider=self.provider | ||||||
|  | |||||||
| @ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase): | |||||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", |             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), |             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||||
|         ): |         ): | ||||||
|             google_workspace_sync.delay(self.provider.pk).get() |             google_workspace_sync.send(self.provider.pk).get_result() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 GoogleWorkspaceProviderUser.objects.filter( |                 GoogleWorkspaceProviderUser.objects.filter( | ||||||
|                     user=different_user, provider=self.provider |                     user=different_user, provider=self.provider | ||||||
|  | |||||||
| @ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin | |||||||
|     ] |     ] | ||||||
|     search_fields = ["name"] |     search_fields = ["name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     sync_single_task = microsoft_entra_sync |     sync_task = microsoft_entra_sync | ||||||
|     sync_objects_task = microsoft_entra_sync_objects |     sync_objects_task = microsoft_entra_sync_objects | ||||||
|  | |||||||
| @ -8,6 +8,7 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import Actor | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -99,6 +100,12 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|         help_text=_("Property mappings used for group creation/updating."), |         help_text=_("Property mappings used for group creation/updating."), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def sync_actor(self) -> Actor: | ||||||
|  |         from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync | ||||||
|  |  | ||||||
|  |         return microsoft_entra_sync | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, |         self, | ||||||
|         model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], |         model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """Microsoft Entra provider task Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "providers_microsoft_entra_sync": { |  | ||||||
|         "task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -2,15 +2,13 @@ | |||||||
|  |  | ||||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||||
| from authentik.enterprise.providers.microsoft_entra.tasks import ( | from authentik.enterprise.providers.microsoft_entra.tasks import ( | ||||||
|     microsoft_entra_sync, |     microsoft_entra_sync_direct_dispatch, | ||||||
|     microsoft_entra_sync_direct, |     microsoft_entra_sync_m2m_dispatch, | ||||||
|     microsoft_entra_sync_m2m, |  | ||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     MicrosoftEntraProvider, |     MicrosoftEntraProvider, | ||||||
|     task_sync_single=microsoft_entra_sync, |     task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch, | ||||||
|     task_sync_direct=microsoft_entra_sync_direct, |     task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch, | ||||||
|     task_sync_m2m=microsoft_entra_sync_m2m, |  | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,37 +1,46 @@ | |||||||
| """Microsoft Entra Provider tasks""" | """Microsoft Entra Provider tasks""" | ||||||
|  |  | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import actor | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||||
| from authentik.events.system_tasks import SystemTask |  | ||||||
| from authentik.lib.sync.outgoing.exceptions import TransientSyncException |  | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(MicrosoftEntraProvider) | sync_tasks = SyncTasks(MicrosoftEntraProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor(description=_("Sync Microsoft Entra provider objects.")) | ||||||
| def microsoft_entra_sync_objects(*args, **kwargs): | def microsoft_entra_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Full sync for Microsoft Entra provider.")) | ||||||
|     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | def microsoft_entra_sync(provider_pk: int, *args, **kwargs): | ||||||
| ) |  | ||||||
| def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs): |  | ||||||
|     """Run full sync for Microsoft Entra provider""" |     """Run full sync for Microsoft Entra provider""" | ||||||
|     return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects) |     return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider.")) | ||||||
| def microsoft_entra_sync_all(): |  | ||||||
|     return sync_tasks.sync_all(microsoft_entra_sync) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) |  | ||||||
| def microsoft_entra_sync_direct(*args, **kwargs): | def microsoft_entra_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor( | ||||||
|  |     description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.") | ||||||
|  | ) | ||||||
|  | def microsoft_entra_sync_direct_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor(description=_("Sync a related object (memberships) for Microsoft Entra provider.")) | ||||||
| def microsoft_entra_sync_m2m(*args, **kwargs): | def microsoft_entra_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor( | ||||||
|  |     description=_( | ||||||
|  |         "Dispatch syncs for a related object (memberships) for Microsoft Entra providers." | ||||||
|  |     ) | ||||||
|  | ) | ||||||
|  | def microsoft_entra_sync_m2m_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs) | ||||||
|  | |||||||
| @ -252,9 +252,13 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|             member_add.assert_called_once() |             member_add.assert_called_once() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 member_add.call_args[0][0].odata_id, |                 member_add.call_args[0][0].odata_id, | ||||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( |                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ | ||||||
|  |                     MicrosoftEntraProviderUser.objects.filter( | ||||||
|                         provider=self.provider, |                         provider=self.provider, | ||||||
|                     ).first().microsoft_id}", |                     ) | ||||||
|  |                     .first() | ||||||
|  |                     .microsoft_id | ||||||
|  |                 }", | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_group_create_member_remove(self): |     def test_group_create_member_remove(self): | ||||||
| @ -311,9 +315,13 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|             member_add.assert_called_once() |             member_add.assert_called_once() | ||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 member_add.call_args[0][0].odata_id, |                 member_add.call_args[0][0].odata_id, | ||||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( |                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ | ||||||
|  |                     MicrosoftEntraProviderUser.objects.filter( | ||||||
|                         provider=self.provider, |                         provider=self.provider, | ||||||
|                     ).first().microsoft_id}", |                     ) | ||||||
|  |                     .first() | ||||||
|  |                     .microsoft_id | ||||||
|  |                 }", | ||||||
|             ) |             ) | ||||||
|             member_remove.assert_called_once() |             member_remove.assert_called_once() | ||||||
|  |  | ||||||
| @ -413,7 +421,7 @@ class MicrosoftEntraGroupTests(TestCase): | |||||||
|                 ), |                 ), | ||||||
|             ) as group_list, |             ) as group_list, | ||||||
|         ): |         ): | ||||||
|             microsoft_entra_sync.delay(self.provider.pk).get() |             microsoft_entra_sync.send(self.provider.pk).get_result() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 MicrosoftEntraProviderGroup.objects.filter( |                 MicrosoftEntraProviderGroup.objects.filter( | ||||||
|                     group=different_group, provider=self.provider |                     group=different_group, provider=self.provider | ||||||
|  | |||||||
| @ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase): | |||||||
|                 AsyncMock(return_value=GroupCollectionResponse(value=[])), |                 AsyncMock(return_value=GroupCollectionResponse(value=[])), | ||||||
|             ), |             ), | ||||||
|         ): |         ): | ||||||
|             microsoft_entra_sync.delay(self.provider.pk).get() |             microsoft_entra_sync.send(self.provider.pk).get_result() | ||||||
|             self.assertTrue( |             self.assertTrue( | ||||||
|                 MicrosoftEntraProviderUser.objects.filter( |                 MicrosoftEntraProviderUser.objects.filter( | ||||||
|                     user=different_user, provider=self.provider |                     user=different_user, provider=self.provider | ||||||
|  | |||||||
| @ -17,6 +17,7 @@ from authentik.crypto.models import CertificateKeyPair | |||||||
| from authentik.lib.models import CreatedUpdatedModel | from authentik.lib.models import CreatedUpdatedModel | ||||||
| from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | ||||||
| from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider | from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider | ||||||
|  | from authentik.tasks.models import TasksModel | ||||||
|  |  | ||||||
|  |  | ||||||
| class EventTypes(models.TextChoices): | class EventTypes(models.TextChoices): | ||||||
| @ -42,7 +43,7 @@ class SSFEventStatus(models.TextChoices): | |||||||
|     SENT = "sent" |     SENT = "sent" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SSFProvider(BackchannelProvider): | class SSFProvider(TasksModel, BackchannelProvider): | ||||||
|     """Shared Signals Framework provider to allow applications to |     """Shared Signals Framework provider to allow applications to | ||||||
|     receive user events from authentik.""" |     receive user events from authentik.""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     EventTypes, |     EventTypes, | ||||||
|     SSFProvider, |     SSFProvider, | ||||||
| ) | ) | ||||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_event | from authentik.enterprise.providers.ssf.tasks import send_ssf_events | ||||||
| from authentik.events.middleware import audit_ignore | from authentik.events.middleware import audit_ignore | ||||||
| from authentik.stages.authenticator.models import Device | from authentik.stages.authenticator.models import Device | ||||||
| from authentik.stages.authenticator_duo.models import DuoDevice | from authentik.stages.authenticator_duo.models import DuoDevice | ||||||
| @ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | |||||||
|  |  | ||||||
|     As this signal is also triggered with a regular logout, we can't be sure |     As this signal is also triggered with a regular logout, we can't be sure | ||||||
|     if the session has been deleted by an admin or by the user themselves.""" |     if the session has been deleted by an admin or by the user themselves.""" | ||||||
|     send_ssf_event( |     send_ssf_events( | ||||||
|         EventTypes.CAEP_SESSION_REVOKED, |         EventTypes.CAEP_SESSION_REVOKED, | ||||||
|         { |         { | ||||||
|             "initiating_entity": "user", |             "initiating_entity": "user", | ||||||
| @ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | |||||||
| @receiver(password_changed) | @receiver(password_changed) | ||||||
| def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): | def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): | ||||||
|     """Credential change trigger (password changed)""" |     """Credential change trigger (password changed)""" | ||||||
|     send_ssf_event( |     send_ssf_events( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         { |         { | ||||||
|             "credential_type": "password", |             "credential_type": "password", | ||||||
| @ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, * | |||||||
|     } |     } | ||||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: |     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||||
|         data["fido2_aaguid"] = instance.aaguid |         data["fido2_aaguid"] = instance.aaguid | ||||||
|     send_ssf_event( |     send_ssf_events( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         data, |         data, | ||||||
|         sub_id={ |         sub_id={ | ||||||
| @ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_): | |||||||
|     } |     } | ||||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: |     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||||
|         data["fido2_aaguid"] = instance.aaguid |         data["fido2_aaguid"] = instance.aaguid | ||||||
|     send_ssf_event( |     send_ssf_events( | ||||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, |         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||||
|         data, |         data, | ||||||
|         sub_id={ |         sub_id={ | ||||||
|  | |||||||
| @ -1,7 +1,11 @@ | |||||||
| from celery import group | from typing import Any | ||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import actor | ||||||
| from requests.exceptions import RequestException | from requests.exceptions import RequestException | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -13,19 +17,16 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     Stream, |     Stream, | ||||||
|     StreamEvent, |     StreamEvent, | ||||||
| ) | ) | ||||||
| from authentik.events.logs import LogEvent |  | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask |  | ||||||
| from authentik.lib.utils.http import get_http_session | from authentik.lib.utils.http import get_http_session | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
| session = get_http_session() | session = get_http_session() | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| def send_ssf_event( | def send_ssf_events( | ||||||
|     event_type: EventTypes, |     event_type: EventTypes, | ||||||
|     data: dict, |     data: dict, | ||||||
|     stream_filter: dict | None = None, |     stream_filter: dict | None = None, | ||||||
| @ -33,7 +34,7 @@ def send_ssf_event( | |||||||
|     **extra_data, |     **extra_data, | ||||||
| ): | ): | ||||||
|     """Wrapper to send an SSF event to multiple streams""" |     """Wrapper to send an SSF event to multiple streams""" | ||||||
|     payload = [] |     events_data = {} | ||||||
|     if not stream_filter: |     if not stream_filter: | ||||||
|         stream_filter = {} |         stream_filter = {} | ||||||
|     stream_filter["events_requested__contains"] = [event_type] |     stream_filter["events_requested__contains"] = [event_type] | ||||||
| @ -41,16 +42,22 @@ def send_ssf_event( | |||||||
|         extra_data.setdefault("txn", request.request_id) |         extra_data.setdefault("txn", request.request_id) | ||||||
|     for stream in Stream.objects.filter(**stream_filter): |     for stream in Stream.objects.filter(**stream_filter): | ||||||
|         event_data = stream.prepare_event_payload(event_type, data, **extra_data) |         event_data = stream.prepare_event_payload(event_type, data, **extra_data) | ||||||
|         payload.append((str(stream.uuid), event_data)) |         events_data[stream.uuid] = event_data | ||||||
|     return _send_ssf_event.delay(payload) |     ssf_events_dispatch.send(events_data) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _check_app_access(stream_uuid: str, event_data: dict) -> bool: | @actor(description=_("Dispatch SSF events.")) | ||||||
|     """Check if event is related to user and if so, check | def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]): | ||||||
|     if the user has access to the application""" |     for stream_uuid, event_data in events_data.items(): | ||||||
|         stream = Stream.objects.filter(pk=stream_uuid).first() |         stream = Stream.objects.filter(pk=stream_uuid).first() | ||||||
|         if not stream: |         if not stream: | ||||||
|         return False |             continue | ||||||
|  |         send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _check_app_access(stream: Stream, event_data: dict) -> bool: | ||||||
|  |     """Check if event is related to user and if so, check | ||||||
|  |     if the user has access to the application""" | ||||||
|     # `event_data` is a dict version of a StreamEvent |     # `event_data` is a dict version of a StreamEvent | ||||||
|     sub_id = event_data.get("payload", {}).get("sub_id", {}) |     sub_id = event_data.get("payload", {}).get("sub_id", {}) | ||||||
|     email = sub_id.get("user", {}).get("email", None) |     email = sub_id.get("user", {}).get("email", None) | ||||||
| @ -65,42 +72,22 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool: | |||||||
|     return engine.passing |     return engine.passing | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Send an SSF event.")) | ||||||
| def _send_ssf_event(event_data: list[tuple[str, dict]]): | def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): | ||||||
|     tasks = [] |     self: Task = CurrentTask.get_task() | ||||||
|     for stream, data in event_data: |  | ||||||
|         if not _check_app_access(stream, data): |  | ||||||
|             continue |  | ||||||
|         event = StreamEvent.objects.create(**data) |  | ||||||
|         tasks.extend(send_single_ssf_event(stream, str(event.uuid))) |  | ||||||
|     main_task = group(*tasks) |  | ||||||
|     main_task() |  | ||||||
|  |  | ||||||
|  |     stream = Stream.objects.filter(pk=stream_uuid).first() | ||||||
| def send_single_ssf_event(stream_id: str, evt_id: str): |  | ||||||
|     stream = Stream.objects.filter(pk=stream_id).first() |  | ||||||
|     if not stream: |     if not stream: | ||||||
|         return |         return | ||||||
|     event = StreamEvent.objects.filter(pk=evt_id).first() |     if not _check_app_access(stream, event_data): | ||||||
|     if not event: |  | ||||||
|         return |         return | ||||||
|  |     event = StreamEvent.objects.create(**event_data) | ||||||
|  |     self.set_uid(event.pk) | ||||||
|     if event.status == SSFEventStatus.SENT: |     if event.status == SSFEventStatus.SENT: | ||||||
|         return |         return | ||||||
|     if stream.delivery_method == DeliveryMethods.RISC_PUSH: |     if stream.delivery_method != DeliveryMethods.RISC_PUSH: | ||||||
|         return [ssf_push_event.si(str(event.pk))] |  | ||||||
|     return [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) |  | ||||||
| def ssf_push_event(self: SystemTask, event_id: str): |  | ||||||
|     self.save_on_success = False |  | ||||||
|     event = StreamEvent.objects.filter(pk=event_id).first() |  | ||||||
|     if not event: |  | ||||||
|         return |  | ||||||
|     self.set_uid(event_id) |  | ||||||
|     if event.status == SSFEventStatus.SENT: |  | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL) |  | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         response = session.post( |         response = session.post( | ||||||
|             event.stream.endpoint_url, |             event.stream.endpoint_url, | ||||||
| @ -110,26 +97,17 @@ def ssf_push_event(self: SystemTask, event_id: str): | |||||||
|         response.raise_for_status() |         response.raise_for_status() | ||||||
|         event.status = SSFEventStatus.SENT |         event.status = SSFEventStatus.SENT | ||||||
|         event.save() |         event.save() | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL) |  | ||||||
|         return |         return | ||||||
|     except RequestException as exc: |     except RequestException as exc: | ||||||
|         LOGGER.warning("Failed to send SSF event", exc=exc) |         LOGGER.warning("Failed to send SSF event", exc=exc) | ||||||
|         self.set_status(TaskStatus.ERROR) |  | ||||||
|         attrs = {} |         attrs = {} | ||||||
|         if exc.response: |         if exc.response: | ||||||
|             attrs["response"] = { |             attrs["response"] = { | ||||||
|                 "content": exc.response.text, |                 "content": exc.response.text, | ||||||
|                 "status": exc.response.status_code, |                 "status": exc.response.status_code, | ||||||
|             } |             } | ||||||
|         self.set_error( |         self.warning(exc) | ||||||
|             exc, |         self.warning("Failed to send request", **attrs) | ||||||
|             LogEvent( |  | ||||||
|                 _("Failed to send request"), |  | ||||||
|                 log_level="warning", |  | ||||||
|                 logger=self.__name__, |  | ||||||
|                 attributes=attrs, |  | ||||||
|             ), |  | ||||||
|         ) |  | ||||||
|         # Re-up the expiry of the stream event |         # Re-up the expiry of the stream event | ||||||
|         event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) |         event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) | ||||||
|         event.status = SSFEventStatus.PENDING_FAILED |         event.status = SSFEventStatus.PENDING_FAILED | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import ( | |||||||
|     SSFProvider, |     SSFProvider, | ||||||
|     Stream, |     Stream, | ||||||
| ) | ) | ||||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_event | from authentik.enterprise.providers.ssf.tasks import send_ssf_events | ||||||
| from authentik.enterprise.providers.ssf.views.base import SSFView | from authentik.enterprise.providers.ssf.views.base import SSFView | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -109,7 +109,7 @@ class StreamView(SSFView): | |||||||
|                 "User does not have permission to create stream for this provider." |                 "User does not have permission to create stream for this provider." | ||||||
|             ) |             ) | ||||||
|         instance: Stream = stream.save(provider=self.provider) |         instance: Stream = stream.save(provider=self.provider) | ||||||
|         send_ssf_event( |         send_ssf_events( | ||||||
|             EventTypes.SET_VERIFICATION, |             EventTypes.SET_VERIFICATION, | ||||||
|             { |             { | ||||||
|                 "state": None, |                 "state": None, | ||||||
|  | |||||||
| @ -1,17 +1,5 @@ | |||||||
| """Enterprise additional settings""" | """Enterprise additional settings""" | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "enterprise_update_usage": { |  | ||||||
|         "task": "authentik.enterprise.tasks.enterprise_update_usage", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| TENANT_APPS = [ | TENANT_APPS = [ | ||||||
|     "authentik.enterprise.audit", |     "authentik.enterprise.audit", | ||||||
|     "authentik.enterprise.policies.unique_password", |     "authentik.enterprise.policies.unique_password", | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ from django.utils.timezone import get_current_timezone | |||||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||||
| from authentik.enterprise.models import License | from authentik.enterprise.models import License | ||||||
| from authentik.enterprise.tasks import enterprise_update_usage | from authentik.enterprise.tasks import enterprise_update_usage | ||||||
|  | from authentik.tasks.schedules.models import Schedule | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=License) | @receiver(pre_save, sender=License) | ||||||
| @ -26,7 +27,7 @@ def pre_save_license(sender: type[License], instance: License, **_): | |||||||
| def post_save_license(sender: type[License], instance: License, **_): | def post_save_license(sender: type[License], instance: License, **_): | ||||||
|     """Trigger license usage calculation when license is saved""" |     """Trigger license usage calculation when license is saved""" | ||||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) |     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|     enterprise_update_usage.delay() |     Schedule.dispatch_by_actor(enterprise_update_usage) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_delete, sender=License) | @receiver(post_delete, sender=License) | ||||||
|  | |||||||
| @ -1,14 +1,11 @@ | |||||||
| """Enterprise tasks""" | """Enterprise tasks""" | ||||||
|  |  | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import actor | ||||||
|  |  | ||||||
| from authentik.enterprise.license import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task |  | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Update enterprise license status.")) | ||||||
| @prefill_task | def enterprise_update_usage(): | ||||||
| def enterprise_update_usage(self: SystemTask): |  | ||||||
|     """Update enterprise license status""" |  | ||||||
|     LicenseKey.get_total().record_usage() |     LicenseKey.get_total().record_usage() | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL) |  | ||||||
|  | |||||||
| @ -1,104 +0,0 @@ | |||||||
| """Tasks API""" |  | ||||||
|  |  | ||||||
| from importlib import import_module |  | ||||||
|  |  | ||||||
| from django.contrib import messages |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from drf_spectacular.types import OpenApiTypes |  | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema |  | ||||||
| from rest_framework.decorators import action |  | ||||||
| from rest_framework.fields import ( |  | ||||||
|     CharField, |  | ||||||
|     ChoiceField, |  | ||||||
|     DateTimeField, |  | ||||||
|     FloatField, |  | ||||||
|     SerializerMethodField, |  | ||||||
| ) |  | ||||||
| from rest_framework.request import Request |  | ||||||
| from rest_framework.response import Response |  | ||||||
| from rest_framework.viewsets import ReadOnlyModelViewSet |  | ||||||
| from structlog.stdlib import get_logger |  | ||||||
|  |  | ||||||
| from authentik.core.api.utils import ModelSerializer |  | ||||||
| from authentik.events.logs import LogEventSerializer |  | ||||||
| from authentik.events.models import SystemTask, TaskStatus |  | ||||||
| from authentik.rbac.decorators import permission_required |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SystemTaskSerializer(ModelSerializer): |  | ||||||
|     """Serialize TaskInfo and TaskResult""" |  | ||||||
|  |  | ||||||
|     name = CharField() |  | ||||||
|     full_name = SerializerMethodField() |  | ||||||
|     uid = CharField(required=False) |  | ||||||
|     description = CharField() |  | ||||||
|     start_timestamp = DateTimeField(read_only=True) |  | ||||||
|     finish_timestamp = DateTimeField(read_only=True) |  | ||||||
|     duration = FloatField(read_only=True) |  | ||||||
|  |  | ||||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) |  | ||||||
|     messages = LogEventSerializer(many=True) |  | ||||||
|  |  | ||||||
|     def get_full_name(self, instance: SystemTask) -> str: |  | ||||||
|         """Get full name with UID""" |  | ||||||
|         if instance.uid: |  | ||||||
|             return f"{instance.name}:{instance.uid}" |  | ||||||
|         return instance.name |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         model = SystemTask |  | ||||||
|         fields = [ |  | ||||||
|             "uuid", |  | ||||||
|             "name", |  | ||||||
|             "full_name", |  | ||||||
|             "uid", |  | ||||||
|             "description", |  | ||||||
|             "start_timestamp", |  | ||||||
|             "finish_timestamp", |  | ||||||
|             "duration", |  | ||||||
|             "status", |  | ||||||
|             "messages", |  | ||||||
|             "expires", |  | ||||||
|             "expiring", |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SystemTaskViewSet(ReadOnlyModelViewSet): |  | ||||||
|     """Read-only view set that returns all background tasks""" |  | ||||||
|  |  | ||||||
|     queryset = SystemTask.objects.all() |  | ||||||
|     serializer_class = SystemTaskSerializer |  | ||||||
|     filterset_fields = ["name", "uid", "status"] |  | ||||||
|     ordering = ["name", "uid", "status"] |  | ||||||
|     search_fields = ["name", "description", "uid", "status"] |  | ||||||
|  |  | ||||||
|     @permission_required(None, ["authentik_events.run_task"]) |  | ||||||
|     @extend_schema( |  | ||||||
|         request=OpenApiTypes.NONE, |  | ||||||
|         responses={ |  | ||||||
|             204: OpenApiResponse(description="Task retried successfully"), |  | ||||||
|             404: OpenApiResponse(description="Task not found"), |  | ||||||
|             500: OpenApiResponse(description="Failed to retry task"), |  | ||||||
|         }, |  | ||||||
|     ) |  | ||||||
|     @action(detail=True, methods=["POST"], permission_classes=[]) |  | ||||||
|     def run(self, request: Request, pk=None) -> Response: |  | ||||||
|         """Run task""" |  | ||||||
|         task: SystemTask = self.get_object() |  | ||||||
|         try: |  | ||||||
|             task_module = import_module(task.task_call_module) |  | ||||||
|             task_func = getattr(task_module, task.task_call_func) |  | ||||||
|             LOGGER.info("Running task", task=task_func) |  | ||||||
|             task_func.delay(*task.task_call_args, **task.task_call_kwargs) |  | ||||||
|             messages.success( |  | ||||||
|                 self.request, |  | ||||||
|                 _("Successfully started task {name}.".format_map({"name": task.name})), |  | ||||||
|             ) |  | ||||||
|             return Response(status=204) |  | ||||||
|         except (ImportError, AttributeError) as exc:  # pragma: no cover |  | ||||||
|             LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc) |  | ||||||
|             # if we get an import error, the module path has probably changed |  | ||||||
|             task.delete() |  | ||||||
|             return Response(status=500) |  | ||||||
| @ -1,12 +1,11 @@ | |||||||
| """authentik events app""" | """authentik events app""" | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
| from prometheus_client import Gauge, Histogram | from prometheus_client import Gauge, Histogram | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.config import CONFIG, ENV_PREFIX | from authentik.lib.config import CONFIG, ENV_PREFIX | ||||||
| from authentik.lib.utils.reflection import path_to_class | from authentik.lib.utils.time import fqdn_rand | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
| # TODO: Deprecated metric - remove in 2024.2 or later | # TODO: Deprecated metric - remove in 2024.2 or later | ||||||
| GAUGE_TASKS = Gauge( | GAUGE_TASKS = Gauge( | ||||||
| @ -35,6 +34,17 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Events" |     verbose_name = "authentik Events" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.events.tasks import notification_cleanup | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=notification_cleanup, | ||||||
|  |                 crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_global |     @ManagedAppConfig.reconcile_global | ||||||
|     def check_deprecations(self): |     def check_deprecations(self): | ||||||
|         """Check for config deprecations""" |         """Check for config deprecations""" | ||||||
| @ -56,41 +66,3 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|                 replacement_env=replace_env, |                 replacement_env=replace_env, | ||||||
|                 message=msg, |                 message=msg, | ||||||
|             ).save() |             ).save() | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_tenant |  | ||||||
|     def prefill_tasks(self): |  | ||||||
|         """Prefill tasks""" |  | ||||||
|         from authentik.events.models import SystemTask |  | ||||||
|         from authentik.events.system_tasks import _prefill_tasks |  | ||||||
|  |  | ||||||
|         for task in _prefill_tasks: |  | ||||||
|             if SystemTask.objects.filter(name=task.name).exists(): |  | ||||||
|                 continue |  | ||||||
|             task.save() |  | ||||||
|             self.logger.debug("prefilled task", task_name=task.name) |  | ||||||
|  |  | ||||||
|     @ManagedAppConfig.reconcile_tenant |  | ||||||
|     def run_scheduled_tasks(self): |  | ||||||
|         """Run schedule tasks which are behind schedule (only applies |  | ||||||
|         to tasks of which we keep metrics)""" |  | ||||||
|         from authentik.events.models import TaskStatus |  | ||||||
|         from authentik.events.system_tasks import SystemTask as CelerySystemTask |  | ||||||
|  |  | ||||||
|         for task in CELERY_APP.conf["beat_schedule"].values(): |  | ||||||
|             schedule = task["schedule"] |  | ||||||
|             if not isinstance(schedule, crontab): |  | ||||||
|                 continue |  | ||||||
|             task_class: CelerySystemTask = path_to_class(task["task"]) |  | ||||||
|             if not isinstance(task_class, CelerySystemTask): |  | ||||||
|                 continue |  | ||||||
|             db_task = task_class.db() |  | ||||||
|             if not db_task: |  | ||||||
|                 continue |  | ||||||
|             due, _ = schedule.is_due(db_task.finish_timestamp) |  | ||||||
|             if due or db_task.status == TaskStatus.UNKNOWN: |  | ||||||
|                 self.logger.debug("Running past-due scheduled task", task=task["task"]) |  | ||||||
|                 task_class.apply_async( |  | ||||||
|                     args=task.get("args", None), |  | ||||||
|                     kwargs=task.get("kwargs", None), |  | ||||||
|                     **task.get("options", {}), |  | ||||||
|                 ) |  | ||||||
|  | |||||||
							
								
								
									
										22
									
								
								authentik/events/migrations/0011_alter_systemtask_options.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								authentik/events/migrations/0011_alter_systemtask_options.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,22 @@ | |||||||
|  | # Generated by Django 5.1.11 on 2025-06-24 15:36 | ||||||
|  |  | ||||||
|  | from django.db import migrations | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AlterModelOptions( | ||||||
|  |             name="systemtask", | ||||||
|  |             options={ | ||||||
|  |                 "default_permissions": (), | ||||||
|  |                 "permissions": (), | ||||||
|  |                 "verbose_name": "System Task", | ||||||
|  |                 "verbose_name_plural": "System Tasks", | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -5,12 +5,11 @@ from datetime import timedelta | |||||||
| from difflib import get_close_matches | from difflib import get_close_matches | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from inspect import currentframe | from inspect import currentframe | ||||||
| from smtplib import SMTPException |  | ||||||
| from typing import Any | from typing import Any | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.db import connection, models | from django.db import models | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| from django.http.request import QueryDict | from django.http.request import QueryDict | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| @ -27,7 +26,6 @@ from authentik.core.middleware import ( | |||||||
|     SESSION_KEY_IMPERSONATE_USER, |     SESSION_KEY_IMPERSONATE_USER, | ||||||
| ) | ) | ||||||
| from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | ||||||
| from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME |  | ||||||
| from authentik.events.context_processors.base import get_context_processors | from authentik.events.context_processors.base import get_context_processors | ||||||
| from authentik.events.utils import ( | from authentik.events.utils import ( | ||||||
|     cleanse_dict, |     cleanse_dict, | ||||||
| @ -43,6 +41,7 @@ from authentik.lib.utils.time import timedelta_from_string | |||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
|  | from authentik.tasks.models import TasksModel | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
| @ -267,7 +266,8 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|             models.Index(fields=["created"]), |             models.Index(fields=["created"]), | ||||||
|             models.Index(fields=["client_ip"]), |             models.Index(fields=["client_ip"]), | ||||||
|             models.Index( |             models.Index( | ||||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" |                 models.F("context__authorized_application"), | ||||||
|  |                 name="authentik_e_ctx_app__idx", | ||||||
|             ), |             ), | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
| @ -281,7 +281,7 @@ class TransportMode(models.TextChoices): | |||||||
|     EMAIL = "email", _("Email") |     EMAIL = "email", _("Email") | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotificationTransport(SerializerModel): | class NotificationTransport(TasksModel, SerializerModel): | ||||||
|     """Action which is executed when a Rule matches""" |     """Action which is executed when a Rule matches""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) |     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
| @ -446,6 +446,8 @@ class NotificationTransport(SerializerModel): | |||||||
|  |  | ||||||
|     def send_email(self, notification: "Notification") -> list[str]: |     def send_email(self, notification: "Notification") -> list[str]: | ||||||
|         """Send notification via global email configuration""" |         """Send notification via global email configuration""" | ||||||
|  |         from authentik.stages.email.tasks import send_mail | ||||||
|  |  | ||||||
|         if notification.user.email.strip() == "": |         if notification.user.email.strip() == "": | ||||||
|             LOGGER.info( |             LOGGER.info( | ||||||
|                 "Discarding notification as user has no email address", |                 "Discarding notification as user has no email address", | ||||||
| @ -487,17 +489,14 @@ class NotificationTransport(SerializerModel): | |||||||
|             template_name="email/event_notification.html", |             template_name="email/event_notification.html", | ||||||
|             template_context=context, |             template_context=context, | ||||||
|         ) |         ) | ||||||
|         # Email is sent directly here, as the call to send() should have been from a task. |         send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self) | ||||||
|         try: |         return [] | ||||||
|             from authentik.stages.email.tasks import send_mail |  | ||||||
|  |  | ||||||
|             return send_mail(mail.__dict__) |  | ||||||
|         except (SMTPException, ConnectionError, OSError) as exc: |  | ||||||
|             raise NotificationTransportError(exc) from exc |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|         from authentik.events.api.notification_transports import NotificationTransportSerializer |         from authentik.events.api.notification_transports import ( | ||||||
|  |             NotificationTransportSerializer, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return NotificationTransportSerializer |         return NotificationTransportSerializer | ||||||
|  |  | ||||||
| @ -547,7 +546,7 @@ class Notification(SerializerModel): | |||||||
|         verbose_name_plural = _("Notifications") |         verbose_name_plural = _("Notifications") | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotificationRule(SerializerModel, PolicyBindingModel): | class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel): | ||||||
|     """Decide when to create a Notification based on policies attached to this object.""" |     """Decide when to create a Notification based on policies attached to this object.""" | ||||||
|  |  | ||||||
|     name = models.TextField(unique=True) |     name = models.TextField(unique=True) | ||||||
| @ -611,7 +610,9 @@ class NotificationWebhookMapping(PropertyMapping): | |||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[type[Serializer]]: |     def serializer(self) -> type[type[Serializer]]: | ||||||
|         from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer |         from authentik.events.api.notification_mappings import ( | ||||||
|  |             NotificationWebhookMappingSerializer, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         return NotificationWebhookMappingSerializer |         return NotificationWebhookMappingSerializer | ||||||
|  |  | ||||||
| @ -624,7 +625,7 @@ class NotificationWebhookMapping(PropertyMapping): | |||||||
|  |  | ||||||
|  |  | ||||||
| class TaskStatus(models.TextChoices): | class TaskStatus(models.TextChoices): | ||||||
|     """Possible states of tasks""" |     """DEPRECATED do not use""" | ||||||
|  |  | ||||||
|     UNKNOWN = "unknown" |     UNKNOWN = "unknown" | ||||||
|     SUCCESSFUL = "successful" |     SUCCESSFUL = "successful" | ||||||
| @ -632,8 +633,8 @@ class TaskStatus(models.TextChoices): | |||||||
|     ERROR = "error" |     ERROR = "error" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SystemTask(SerializerModel, ExpiringModel): | class SystemTask(ExpiringModel): | ||||||
|     """Info about a system task running in the background along with details to restart the task""" |     """DEPRECATED do not use""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) |     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
|     name = models.TextField() |     name = models.TextField() | ||||||
| @ -653,41 +654,13 @@ class SystemTask(SerializerModel, ExpiringModel): | |||||||
|     task_call_args = models.JSONField(default=list) |     task_call_args = models.JSONField(default=list) | ||||||
|     task_call_kwargs = models.JSONField(default=dict) |     task_call_kwargs = models.JSONField(default=dict) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def serializer(self) -> type[Serializer]: |  | ||||||
|         from authentik.events.api.tasks import SystemTaskSerializer |  | ||||||
|  |  | ||||||
|         return SystemTaskSerializer |  | ||||||
|  |  | ||||||
|     def update_metrics(self): |  | ||||||
|         """Update prometheus metrics""" |  | ||||||
|         # TODO: Deprecated metric - remove in 2024.2 or later |  | ||||||
|         GAUGE_TASKS.labels( |  | ||||||
|             tenant=connection.schema_name, |  | ||||||
|             task_name=self.name, |  | ||||||
|             task_uid=self.uid or "", |  | ||||||
|             status=self.status.lower(), |  | ||||||
|         ).set(self.duration) |  | ||||||
|         SYSTEM_TASK_TIME.labels( |  | ||||||
|             tenant=connection.schema_name, |  | ||||||
|             task_name=self.name, |  | ||||||
|             task_uid=self.uid or "", |  | ||||||
|         ).observe(self.duration) |  | ||||||
|         SYSTEM_TASK_STATUS.labels( |  | ||||||
|             tenant=connection.schema_name, |  | ||||||
|             task_name=self.name, |  | ||||||
|             task_uid=self.uid or "", |  | ||||||
|             status=self.status.lower(), |  | ||||||
|         ).inc() |  | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"System Task {self.name}" |         return f"System Task {self.name}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("name", "uid"),) |         unique_together = (("name", "uid"),) | ||||||
|         # Remove "add", "change" and "delete" permissions as those are not used |         default_permissions = () | ||||||
|         default_permissions = ["view"] |         permissions = () | ||||||
|         permissions = [("run_task", _("Run task"))] |  | ||||||
|         verbose_name = _("System Task") |         verbose_name = _("System Task") | ||||||
|         verbose_name_plural = _("System Tasks") |         verbose_name_plural = _("System Tasks") | ||||||
|         indexes = ExpiringModel.Meta.indexes |         indexes = ExpiringModel.Meta.indexes | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """Event Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "events_notification_cleanup": { |  | ||||||
|         "task": "authentik.events.tasks.notification_cleanup", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -12,13 +12,10 @@ from rest_framework.request import Request | |||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.core.signals import login_failed, password_changed | from authentik.core.signals import login_failed, password_changed | ||||||
| from authentik.events.apps import SYSTEM_TASK_STATUS | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.models import Event, EventAction, SystemTask |  | ||||||
| from authentik.events.tasks import event_notification_handler, gdpr_cleanup |  | ||||||
| from authentik.flows.models import Stage | from authentik.flows.models import Stage | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan | from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
| from authentik.root.monitoring import monitoring_set |  | ||||||
| from authentik.stages.invitation.models import Invitation | from authentik.stages.invitation.models import Invitation | ||||||
| from authentik.stages.invitation.signals import invitation_used | from authentik.stages.invitation.signals import invitation_used | ||||||
| from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | ||||||
| @ -114,19 +111,15 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest | |||||||
| @receiver(post_save, sender=Event) | @receiver(post_save, sender=Event) | ||||||
| def event_post_save_notification(sender, instance: Event, **_): | def event_post_save_notification(sender, instance: Event, **_): | ||||||
|     """Start task to check if any policies trigger an notification on this event""" |     """Start task to check if any policies trigger an notification on this event""" | ||||||
|     event_notification_handler.delay(instance.event_uuid.hex) |     from authentik.events.tasks import event_trigger_dispatch | ||||||
|  |  | ||||||
|  |     event_trigger_dispatch.send(instance.event_uuid) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=User) | @receiver(pre_delete, sender=User) | ||||||
| def event_user_pre_delete_cleanup(sender, instance: User, **_): | def event_user_pre_delete_cleanup(sender, instance: User, **_): | ||||||
|     """If gdpr_compliance is enabled, remove all the user's events""" |     """If gdpr_compliance is enabled, remove all the user's events""" | ||||||
|  |     from authentik.events.tasks import gdpr_cleanup | ||||||
|  |  | ||||||
|     if get_current_tenant().gdpr_compliance: |     if get_current_tenant().gdpr_compliance: | ||||||
|         gdpr_cleanup.delay(instance.pk) |         gdpr_cleanup.send(instance.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(monitoring_set) |  | ||||||
| def monitoring_system_task(sender, **_): |  | ||||||
|     """Update metrics when task is saved""" |  | ||||||
|     SYSTEM_TASK_STATUS.clear() |  | ||||||
|     for task in SystemTask.objects.all(): |  | ||||||
|         task.update_metrics() |  | ||||||
|  | |||||||
| @ -1,156 +0,0 @@ | |||||||
| """Monitored tasks""" |  | ||||||
|  |  | ||||||
| from datetime import datetime, timedelta |  | ||||||
| from time import perf_counter |  | ||||||
| from typing import Any |  | ||||||
|  |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from structlog.stdlib import BoundLogger, get_logger |  | ||||||
| from tenant_schemas_celery.task import TenantTask |  | ||||||
|  |  | ||||||
| from authentik.events.logs import LogEvent |  | ||||||
| from authentik.events.models import Event, EventAction, TaskStatus |  | ||||||
| from authentik.events.models import SystemTask as DBSystemTask |  | ||||||
| from authentik.events.utils import sanitize_item |  | ||||||
| from authentik.lib.utils.errors import exception_to_string |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SystemTask(TenantTask): |  | ||||||
|     """Task which can save its state to the cache""" |  | ||||||
|  |  | ||||||
|     logger: BoundLogger |  | ||||||
|  |  | ||||||
|     # For tasks that should only be listed if they failed, set this to False |  | ||||||
|     save_on_success: bool |  | ||||||
|  |  | ||||||
|     _status: TaskStatus |  | ||||||
|     _messages: list[LogEvent] |  | ||||||
|  |  | ||||||
|     _uid: str | None |  | ||||||
|     # Precise start time from perf_counter |  | ||||||
|     _start_precise: float | None = None |  | ||||||
|     _start: datetime | None = None |  | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs) -> None: |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|         self._status = TaskStatus.SUCCESSFUL |  | ||||||
|         self.save_on_success = True |  | ||||||
|         self._uid = None |  | ||||||
|         self._status = None |  | ||||||
|         self._messages = [] |  | ||||||
|         self.result_timeout_hours = 6 |  | ||||||
|  |  | ||||||
|     def set_uid(self, uid: str): |  | ||||||
|         """Set UID, so in the case of an unexpected error its saved correctly""" |  | ||||||
|         self._uid = uid |  | ||||||
|  |  | ||||||
|     def set_status(self, status: TaskStatus, *messages: LogEvent): |  | ||||||
|         """Set result for current run, will overwrite previous result.""" |  | ||||||
|         self._status = status |  | ||||||
|         self._messages = list(messages) |  | ||||||
|         for idx, msg in enumerate(self._messages): |  | ||||||
|             if not isinstance(msg, LogEvent): |  | ||||||
|                 self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") |  | ||||||
|  |  | ||||||
|     def set_error(self, exception: Exception, *messages: LogEvent): |  | ||||||
|         """Set result to error and save exception""" |  | ||||||
|         self._status = TaskStatus.ERROR |  | ||||||
|         self._messages = list(messages) |  | ||||||
|         self._messages.extend( |  | ||||||
|             [LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")] |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def before_start(self, task_id, args, kwargs): |  | ||||||
|         self._start_precise = perf_counter() |  | ||||||
|         self._start = now() |  | ||||||
|         self.logger = get_logger().bind(task_id=task_id) |  | ||||||
|         return super().before_start(task_id, args, kwargs) |  | ||||||
|  |  | ||||||
|     def db(self) -> DBSystemTask | None: |  | ||||||
|         """Get DB object for latest task""" |  | ||||||
|         return DBSystemTask.objects.filter( |  | ||||||
|             name=self.__name__, |  | ||||||
|             uid=self._uid, |  | ||||||
|         ).first() |  | ||||||
|  |  | ||||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): |  | ||||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) |  | ||||||
|         if not self._status: |  | ||||||
|             return |  | ||||||
|         if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success: |  | ||||||
|             DBSystemTask.objects.filter( |  | ||||||
|                 name=self.__name__, |  | ||||||
|                 uid=self._uid, |  | ||||||
|             ).delete() |  | ||||||
|             return |  | ||||||
|         DBSystemTask.objects.update_or_create( |  | ||||||
|             name=self.__name__, |  | ||||||
|             uid=self._uid, |  | ||||||
|             defaults={ |  | ||||||
|                 "description": self.__doc__, |  | ||||||
|                 "start_timestamp": self._start or now(), |  | ||||||
|                 "finish_timestamp": now(), |  | ||||||
|                 "duration": max(perf_counter() - self._start_precise, 0), |  | ||||||
|                 "task_call_module": self.__module__, |  | ||||||
|                 "task_call_func": self.__name__, |  | ||||||
|                 "task_call_args": sanitize_item(args), |  | ||||||
|                 "task_call_kwargs": sanitize_item(kwargs), |  | ||||||
|                 "status": self._status, |  | ||||||
|                 "messages": sanitize_item(self._messages), |  | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), |  | ||||||
|                 "expiring": True, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): |  | ||||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) |  | ||||||
|         if not self._status: |  | ||||||
|             self.set_error(exc) |  | ||||||
|         DBSystemTask.objects.update_or_create( |  | ||||||
|             name=self.__name__, |  | ||||||
|             uid=self._uid, |  | ||||||
|             defaults={ |  | ||||||
|                 "description": self.__doc__, |  | ||||||
|                 "start_timestamp": self._start or now(), |  | ||||||
|                 "finish_timestamp": now(), |  | ||||||
|                 "duration": max(perf_counter() - self._start_precise, 0), |  | ||||||
|                 "task_call_module": self.__module__, |  | ||||||
|                 "task_call_func": self.__name__, |  | ||||||
|                 "task_call_args": sanitize_item(args), |  | ||||||
|                 "task_call_kwargs": sanitize_item(kwargs), |  | ||||||
|                 "status": self._status, |  | ||||||
|                 "messages": sanitize_item(self._messages), |  | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), |  | ||||||
|                 "expiring": True, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         Event.new( |  | ||||||
|             EventAction.SYSTEM_TASK_EXCEPTION, |  | ||||||
|             message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}", |  | ||||||
|         ).save() |  | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |  | ||||||
|         raise NotImplementedError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def prefill_task(func): |  | ||||||
|     """Ensure a task's details are always in cache, so it can always be triggered via API""" |  | ||||||
|     _prefill_tasks.append( |  | ||||||
|         DBSystemTask( |  | ||||||
|             name=func.__name__, |  | ||||||
|             description=func.__doc__, |  | ||||||
|             start_timestamp=now(), |  | ||||||
|             finish_timestamp=now(), |  | ||||||
|             status=TaskStatus.UNKNOWN, |  | ||||||
|             messages=sanitize_item([_("Task has not been run yet.")]), |  | ||||||
|             task_call_module=func.__module__, |  | ||||||
|             task_call_func=func.__name__, |  | ||||||
|             expiring=False, |  | ||||||
|             duration=0, |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     return func |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _prefill_tasks = [] |  | ||||||
| @ -1,41 +1,49 @@ | |||||||
| """Event notification tasks""" | """Event notification tasks""" | ||||||
|  |  | ||||||
|  | from uuid import UUID | ||||||
|  |  | ||||||
| from django.db.models.query_utils import Q | from django.db.models.query_utils import Q | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import actor | ||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.expression.exceptions import PropertyMappingExpressionException |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
|     Event, |     Event, | ||||||
|     Notification, |     Notification, | ||||||
|     NotificationRule, |     NotificationRule, | ||||||
|     NotificationTransport, |     NotificationTransport, | ||||||
|     NotificationTransportError, |  | ||||||
|     TaskStatus, |  | ||||||
| ) | ) | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task |  | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.models import PolicyBinding, PolicyEngineMode | from authentik.policies.models import PolicyBinding, PolicyEngineMode | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Dispatch new event notifications.")) | ||||||
| def event_notification_handler(event_uuid: str): | def event_trigger_dispatch(event_uuid: UUID): | ||||||
|     """Start task for each trigger definition""" |  | ||||||
|     for trigger in NotificationRule.objects.all(): |     for trigger in NotificationRule.objects.all(): | ||||||
|         event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") |         event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor( | ||||||
| def event_trigger_handler(event_uuid: str, trigger_name: str): |     description=_( | ||||||
|  |         "Check if policies attached to NotificationRule match event " | ||||||
|  |         "and dispatch notification tasks." | ||||||
|  |     ) | ||||||
|  | ) | ||||||
|  | def event_trigger_handler(event_uuid: UUID, trigger_name: str): | ||||||
|     """Check if policies attached to NotificationRule match event""" |     """Check if policies attached to NotificationRule match event""" | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|  |  | ||||||
|     event: Event = Event.objects.filter(event_uuid=event_uuid).first() |     event: Event = Event.objects.filter(event_uuid=event_uuid).first() | ||||||
|     if not event: |     if not event: | ||||||
|         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) |         self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() |     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() | ||||||
|     if not trigger: |     if not trigger: | ||||||
|         return |         return | ||||||
| @ -70,34 +78,27 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|  |  | ||||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) |     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||||
|     # Create the notification objects |     # Create the notification objects | ||||||
|  |     count = 0 | ||||||
|     for transport in trigger.transports.all(): |     for transport in trigger.transports.all(): | ||||||
|         for user in trigger.destination_users(event): |         for user in trigger.destination_users(event): | ||||||
|             LOGGER.debug("created notification") |             notification_transport.send_with_options( | ||||||
|             notification_transport.apply_async( |                 args=( | ||||||
|                 args=[ |  | ||||||
|                     transport.pk, |                     transport.pk, | ||||||
|                     str(event.pk), |                     event.pk, | ||||||
|                     user.pk, |                     user.pk, | ||||||
|                     str(trigger.pk), |                     trigger.pk, | ||||||
|                 ], |                 ), | ||||||
|                 queue="authentik_events", |                 rel_obj=transport, | ||||||
|             ) |             ) | ||||||
|  |             count += 1 | ||||||
|             if transport.send_once: |             if transport.send_once: | ||||||
|                 break |                 break | ||||||
|  |     self.info(f"Created {count} notification tasks") | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Send notification.")) | ||||||
|     bind=True, | def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): | ||||||
|     autoretry_for=(NotificationTransportError,), |  | ||||||
|     retry_backoff=True, |  | ||||||
|     base=SystemTask, |  | ||||||
| ) |  | ||||||
| def notification_transport( |  | ||||||
|     self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str |  | ||||||
| ): |  | ||||||
|     """Send notification over specified transport""" |     """Send notification over specified transport""" | ||||||
|     self.save_on_success = False |  | ||||||
|     try: |  | ||||||
|     event = Event.objects.filter(pk=event_pk).first() |     event = Event.objects.filter(pk=event_pk).first() | ||||||
|     if not event: |     if not event: | ||||||
|         return |         return | ||||||
| @ -110,17 +111,13 @@ def notification_transport( | |||||||
|     notification = Notification( |     notification = Notification( | ||||||
|         severity=trigger.severity, body=event.summary, event=event, user=user |         severity=trigger.severity, body=event.summary, event=event, user=user | ||||||
|     ) |     ) | ||||||
|         transport = NotificationTransport.objects.filter(pk=transport_pk).first() |     transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first() | ||||||
|     if not transport: |     if not transport: | ||||||
|         return |         return | ||||||
|     transport.send(notification) |     transport.send(notification) | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL) |  | ||||||
|     except (NotificationTransportError, PropertyMappingExpressionException) as exc: |  | ||||||
|         self.set_error(exc) |  | ||||||
|         raise exc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Cleanup events for GDPR compliance.")) | ||||||
| def gdpr_cleanup(user_pk: int): | def gdpr_cleanup(user_pk: int): | ||||||
|     """cleanup events from gdpr_compliance""" |     """cleanup events from gdpr_compliance""" | ||||||
|     events = Event.objects.filter(user__pk=user_pk) |     events = Event.objects.filter(user__pk=user_pk) | ||||||
| @ -128,12 +125,12 @@ def gdpr_cleanup(user_pk: int): | |||||||
|     events.delete() |     events.delete() | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Cleanup seen notifications and notifications whose event expired.")) | ||||||
| @prefill_task | def notification_cleanup(): | ||||||
| def notification_cleanup(self: SystemTask): |  | ||||||
|     """Cleanup seen notifications and notifications whose event expired.""" |     """Cleanup seen notifications and notifications whose event expired.""" | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) |     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) | ||||||
|     amount = notifications.count() |     amount = notifications.count() | ||||||
|     notifications.delete() |     notifications.delete() | ||||||
|     LOGGER.debug("Expired notifications", amount=amount) |     LOGGER.debug("Expired notifications", amount=amount) | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications") |     self.info(f"Expired {amount} Notifications") | ||||||
|  | |||||||
| @ -1,103 +0,0 @@ | |||||||
| """Test Monitored tasks""" |  | ||||||
|  |  | ||||||
| from json import loads |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.tasks import clean_expired_models |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
| from authentik.events.models import SystemTask as DBSystemTask |  | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSystemTasks(APITestCase): |  | ||||||
|     """Test Monitored tasks""" |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         super().setUp() |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|  |  | ||||||
|     def test_failed_successful_remove_state(self): |  | ||||||
|         """Test that a task with `save_on_success` set to `False` that failed saves |  | ||||||
|         a state, and upon successful completion will delete the state""" |  | ||||||
|         should_fail = True |  | ||||||
|         uid = generate_id() |  | ||||||
|  |  | ||||||
|         @CELERY_APP.task( |  | ||||||
|             bind=True, |  | ||||||
|             base=SystemTask, |  | ||||||
|         ) |  | ||||||
|         def test_task(self: SystemTask): |  | ||||||
|             self.save_on_success = False |  | ||||||
|             self.set_uid(uid) |  | ||||||
|             self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL) |  | ||||||
|  |  | ||||||
|         # First test successful run |  | ||||||
|         should_fail = False |  | ||||||
|         test_task.delay().get() |  | ||||||
|         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) |  | ||||||
|  |  | ||||||
|         # Then test failed |  | ||||||
|         should_fail = True |  | ||||||
|         test_task.delay().get() |  | ||||||
|         task = DBSystemTask.objects.filter(name="test_task", uid=uid).first() |  | ||||||
|         self.assertEqual(task.status, TaskStatus.ERROR) |  | ||||||
|  |  | ||||||
|         # Then after that, the state should be removed |  | ||||||
|         should_fail = False |  | ||||||
|         test_task.delay().get() |  | ||||||
|         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) |  | ||||||
|  |  | ||||||
|     def test_tasks(self): |  | ||||||
|         """Test Task API""" |  | ||||||
|         clean_expired_models.delay().get() |  | ||||||
|         response = self.client.get(reverse("authentik_api:systemtask-list")) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         body = loads(response.content) |  | ||||||
|         self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"])) |  | ||||||
|  |  | ||||||
|     def test_tasks_single(self): |  | ||||||
|         """Test Task API (read single)""" |  | ||||||
|         clean_expired_models.delay().get() |  | ||||||
|         task = DBSystemTask.objects.filter(name="clean_expired_models").first() |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:systemtask-detail", |  | ||||||
|                 kwargs={"pk": str(task.pk)}, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         body = loads(response.content) |  | ||||||
|         self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value) |  | ||||||
|         self.assertEqual(body["name"], "clean_expired_models") |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"}) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 404) |  | ||||||
|  |  | ||||||
|     def test_tasks_run(self): |  | ||||||
|         """Test Task API (run)""" |  | ||||||
|         clean_expired_models.delay().get() |  | ||||||
|         task = DBSystemTask.objects.filter(name="clean_expired_models").first() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:systemtask-run", |  | ||||||
|                 kwargs={"pk": str(task.pk)}, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|  |  | ||||||
|     def test_tasks_run_404(self): |  | ||||||
|         """Test Task API (run, 404)""" |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:systemtask-run", |  | ||||||
|                 kwargs={"pk": "qwerqewrqrqewrqewr"}, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 404) |  | ||||||
| @ -5,13 +5,11 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin | |||||||
| from authentik.events.api.notification_rules import NotificationRuleViewSet | from authentik.events.api.notification_rules import NotificationRuleViewSet | ||||||
| from authentik.events.api.notification_transports import NotificationTransportViewSet | from authentik.events.api.notification_transports import NotificationTransportViewSet | ||||||
| from authentik.events.api.notifications import NotificationViewSet | from authentik.events.api.notifications import NotificationViewSet | ||||||
| from authentik.events.api.tasks import SystemTaskViewSet |  | ||||||
|  |  | ||||||
| api_urlpatterns = [ | api_urlpatterns = [ | ||||||
|     ("events/events", EventViewSet), |     ("events/events", EventViewSet), | ||||||
|     ("events/notifications", NotificationViewSet), |     ("events/notifications", NotificationViewSet), | ||||||
|     ("events/transports", NotificationTransportViewSet), |     ("events/transports", NotificationTransportViewSet), | ||||||
|     ("events/rules", NotificationRuleViewSet), |     ("events/rules", NotificationRuleViewSet), | ||||||
|     ("events/system_tasks", SystemTaskViewSet), |  | ||||||
|     ("propertymappings/notification", NotificationWebhookMappingViewSet), |     ("propertymappings/notification", NotificationWebhookMappingViewSet), | ||||||
| ] | ] | ||||||
|  | |||||||
| @ -41,6 +41,7 @@ REDIS_ENV_KEYS = [ | |||||||
| # Old key -> new key | # Old key -> new key | ||||||
| DEPRECATIONS = { | DEPRECATIONS = { | ||||||
|     "geoip": "events.context_processors.geoip", |     "geoip": "events.context_processors.geoip", | ||||||
|  |     "worker.concurrency": "worker.processes", | ||||||
|     "redis.broker_url": "broker.url", |     "redis.broker_url": "broker.url", | ||||||
|     "redis.broker_transport_options": "broker.transport_options", |     "redis.broker_transport_options": "broker.transport_options", | ||||||
|     "redis.cache_timeout": "cache.timeout", |     "redis.cache_timeout": "cache.timeout", | ||||||
|  | |||||||
| @ -21,6 +21,10 @@ def start_debug_server(**kwargs) -> bool: | |||||||
|  |  | ||||||
|     listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") |     listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") | ||||||
|     host, _, port = listen.rpartition(":") |     host, _, port = listen.rpartition(":") | ||||||
|  |     try: | ||||||
|         debugpy.listen((host, int(port)), **kwargs)  # nosec |         debugpy.listen((host, int(port)), **kwargs)  # nosec | ||||||
|  |     except RuntimeError: | ||||||
|  |         LOGGER.warning("Could not start debug server. Continuing without") | ||||||
|  |         return False | ||||||
|     LOGGER.debug("Starting debug server", host=host, port=port) |     LOGGER.debug("Starting debug server", host=host, port=port) | ||||||
|     return True |     return True | ||||||
|  | |||||||
| @ -157,7 +157,14 @@ web: | |||||||
|   path: / |   path: / | ||||||
|  |  | ||||||
| worker: | worker: | ||||||
|   concurrency: 2 |   processes: 2 | ||||||
|  |   threads: 1 | ||||||
|  |   consumer_listen_timeout: "seconds=30" | ||||||
|  |   task_max_retries: 20 | ||||||
|  |   task_default_time_limit: "minutes=10" | ||||||
|  |   task_purge_interval: "days=1" | ||||||
|  |   task_expiration: "days=30" | ||||||
|  |   scheduler_interval: "seconds=60" | ||||||
|  |  | ||||||
| storage: | storage: | ||||||
|   media: |   media: | ||||||
|  | |||||||
| @ -88,7 +88,6 @@ def get_logger_config(): | |||||||
|         "authentik": global_level, |         "authentik": global_level, | ||||||
|         "django": "WARNING", |         "django": "WARNING", | ||||||
|         "django.request": "ERROR", |         "django.request": "ERROR", | ||||||
|         "celery": "WARNING", |  | ||||||
|         "selenium": "WARNING", |         "selenium": "WARNING", | ||||||
|         "docker": "WARNING", |         "docker": "WARNING", | ||||||
|         "urllib3": "WARNING", |         "urllib3": "WARNING", | ||||||
|  | |||||||
| @ -3,8 +3,6 @@ | |||||||
| from asyncio.exceptions import CancelledError | from asyncio.exceptions import CancelledError | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError |  | ||||||
| from celery.exceptions import CeleryError |  | ||||||
| from channels_redis.core import ChannelFull | from channels_redis.core import ChannelFull | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | ||||||
| @ -22,7 +20,6 @@ from sentry_sdk import HttpTransport, get_current_scope | |||||||
| from sentry_sdk import init as sentry_sdk_init | from sentry_sdk import init as sentry_sdk_init | ||||||
| from sentry_sdk.api import set_tag | from sentry_sdk.api import set_tag | ||||||
| from sentry_sdk.integrations.argv import ArgvIntegration | from sentry_sdk.integrations.argv import ArgvIntegration | ||||||
| from sentry_sdk.integrations.celery import CeleryIntegration |  | ||||||
| from sentry_sdk.integrations.django import DjangoIntegration | from sentry_sdk.integrations.django import DjangoIntegration | ||||||
| from sentry_sdk.integrations.redis import RedisIntegration | from sentry_sdk.integrations.redis import RedisIntegration | ||||||
| from sentry_sdk.integrations.socket import SocketIntegration | from sentry_sdk.integrations.socket import SocketIntegration | ||||||
| @ -71,10 +68,6 @@ ignored_classes = ( | |||||||
|     LocalProtocolError, |     LocalProtocolError, | ||||||
|     # rest_framework error |     # rest_framework error | ||||||
|     APIException, |     APIException, | ||||||
|     # celery errors |  | ||||||
|     WorkerLostError, |  | ||||||
|     CeleryError, |  | ||||||
|     SoftTimeLimitExceeded, |  | ||||||
|     # custom baseclass |     # custom baseclass | ||||||
|     SentryIgnoredException, |     SentryIgnoredException, | ||||||
|     # ldap errors |     # ldap errors | ||||||
| @ -115,7 +108,6 @@ def sentry_init(**sentry_init_kwargs): | |||||||
|             ArgvIntegration(), |             ArgvIntegration(), | ||||||
|             StdlibIntegration(), |             StdlibIntegration(), | ||||||
|             DjangoIntegration(transaction_style="function_name", cache_spans=True), |             DjangoIntegration(transaction_style="function_name", cache_spans=True), | ||||||
|             CeleryIntegration(), |  | ||||||
|             RedisIntegration(), |             RedisIntegration(), | ||||||
|             ThreadingIntegration(propagate_hub=True), |             ThreadingIntegration(propagate_hub=True), | ||||||
|             SocketIntegration(), |             SocketIntegration(), | ||||||
| @ -160,14 +152,11 @@ def before_send(event: dict, hint: dict) -> dict | None: | |||||||
|             return None |             return None | ||||||
|     if "logger" in event: |     if "logger" in event: | ||||||
|         if event["logger"] in [ |         if event["logger"] in [ | ||||||
|             "kombu", |  | ||||||
|             "asyncio", |             "asyncio", | ||||||
|             "multiprocessing", |             "multiprocessing", | ||||||
|             "django_redis", |             "django_redis", | ||||||
|             "django.security.DisallowedHost", |             "django.security.DisallowedHost", | ||||||
|             "django_redis.cache", |             "django_redis.cache", | ||||||
|             "celery.backends.redis", |  | ||||||
|             "celery.worker", |  | ||||||
|             "paramiko.transport", |             "paramiko.transport", | ||||||
|         ]: |         ]: | ||||||
|             return None |             return None | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								authentik/lib/sync/api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								authentik/lib/sync/api.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | |||||||
|  | from rest_framework.fields import BooleanField, ChoiceField, DateTimeField | ||||||
|  |  | ||||||
|  | from authentik.core.api.utils import PassiveSerializer | ||||||
|  | from authentik.tasks.models import TaskStatus | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SyncStatusSerializer(PassiveSerializer): | ||||||
|  |     """Provider/source sync status""" | ||||||
|  |  | ||||||
|  |     is_running = BooleanField() | ||||||
|  |     last_successful_sync = DateTimeField(required=False) | ||||||
|  |     last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices) | ||||||
| @ -1,7 +1,7 @@ | |||||||
| """Sync constants""" | """Sync constants""" | ||||||
|  |  | ||||||
| PAGE_SIZE = 100 | PAGE_SIZE = 100 | ||||||
| PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour | PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000  # Half an hour | ||||||
| HTTP_CONFLICT = 409 | HTTP_CONFLICT = 409 | ||||||
| HTTP_NO_CONTENT = 204 | HTTP_NO_CONTENT = 204 | ||||||
| HTTP_SERVICE_UNAVAILABLE = 503 | HTTP_SERVICE_UNAVAILABLE = 503 | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| from celery import Task | from dramatiq.actor import Actor | ||||||
| from django.utils.text import slugify | from drf_spectacular.utils import extend_schema | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema |  | ||||||
| from guardian.shortcuts import get_objects_for_user |  | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import BooleanField, CharField, ChoiceField | from rest_framework.fields import BooleanField, CharField, ChoiceField | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| @ -9,18 +7,12 @@ from rest_framework.response import Response | |||||||
|  |  | ||||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.api.tasks import SystemTaskSerializer | from authentik.events.logs import LogEventSerializer | ||||||
| from authentik.events.logs import LogEvent, LogEventSerializer | from authentik.lib.sync.api import SyncStatusSerializer | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.rbac.filters import ObjectFilter | from authentik.rbac.filters import ObjectFilter | ||||||
|  | from authentik.tasks.models import Task, TaskStatus | ||||||
|  |  | ||||||
| class SyncStatusSerializer(PassiveSerializer): |  | ||||||
|     """Provider sync status""" |  | ||||||
|  |  | ||||||
|     is_running = BooleanField(read_only=True) |  | ||||||
|     tasks = SystemTaskSerializer(many=True, read_only=True) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SyncObjectSerializer(PassiveSerializer): | class SyncObjectSerializer(PassiveSerializer): | ||||||
| @ -45,15 +37,10 @@ class SyncObjectResultSerializer(PassiveSerializer): | |||||||
| class OutgoingSyncProviderStatusMixin: | class OutgoingSyncProviderStatusMixin: | ||||||
|     """Common API Endpoints for Outgoing sync providers""" |     """Common API Endpoints for Outgoing sync providers""" | ||||||
|  |  | ||||||
|     sync_single_task: type[Task] = None |     sync_task: Actor | ||||||
|     sync_objects_task: type[Task] = None |     sync_objects_task: Actor | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema(responses={200: SyncStatusSerializer()}) | ||||||
|         responses={ |  | ||||||
|             200: SyncStatusSerializer(), |  | ||||||
|             404: OpenApiResponse(description="Task not found"), |  | ||||||
|         } |  | ||||||
|     ) |  | ||||||
|     @action( |     @action( | ||||||
|         methods=["GET"], |         methods=["GET"], | ||||||
|         detail=True, |         detail=True, | ||||||
| @ -64,18 +51,39 @@ class OutgoingSyncProviderStatusMixin: | |||||||
|     def sync_status(self, request: Request, pk: int) -> Response: |     def sync_status(self, request: Request, pk: int) -> Response: | ||||||
|         """Get provider's sync status""" |         """Get provider's sync status""" | ||||||
|         provider: OutgoingSyncProvider = self.get_object() |         provider: OutgoingSyncProvider = self.get_object() | ||||||
|         tasks = list( |  | ||||||
|             get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( |         status = {} | ||||||
|                 name=self.sync_single_task.__name__, |  | ||||||
|                 uid=slugify(provider.name), |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         with provider.sync_lock as lock_acquired: |         with provider.sync_lock as lock_acquired: | ||||||
|             status = { |  | ||||||
|                 "tasks": tasks, |  | ||||||
|             # If we could not acquire the lock, it means a task is using it, and thus is running |             # If we could not acquire the lock, it means a task is using it, and thus is running | ||||||
|                 "is_running": not lock_acquired, |             status["is_running"] = not lock_acquired | ||||||
|             } |  | ||||||
|  |         sync_schedule = None | ||||||
|  |         for schedule in provider.schedules.all(): | ||||||
|  |             if schedule.actor_name == self.sync_task.actor_name: | ||||||
|  |                 sync_schedule = schedule | ||||||
|  |  | ||||||
|  |         if not sync_schedule: | ||||||
|  |             return Response(SyncStatusSerializer(status).data) | ||||||
|  |  | ||||||
|  |         last_task: Task = ( | ||||||
|  |             sync_schedule.tasks.exclude( | ||||||
|  |                 aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) | ||||||
|  |             ) | ||||||
|  |             .order_by("-mtime") | ||||||
|  |             .first() | ||||||
|  |         ) | ||||||
|  |         last_successful_task: Task = ( | ||||||
|  |             sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) | ||||||
|  |             .order_by("-mtime") | ||||||
|  |             .first() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         if last_task: | ||||||
|  |             status["last_sync_status"] = last_task.aggregated_status | ||||||
|  |         if last_successful_task: | ||||||
|  |             status["last_successful_sync"] = last_successful_task.mtime | ||||||
|  |  | ||||||
|         return Response(SyncStatusSerializer(status).data) |         return Response(SyncStatusSerializer(status).data) | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
| @ -94,14 +102,20 @@ class OutgoingSyncProviderStatusMixin: | |||||||
|         provider: OutgoingSyncProvider = self.get_object() |         provider: OutgoingSyncProvider = self.get_object() | ||||||
|         params = SyncObjectSerializer(data=request.data) |         params = SyncObjectSerializer(data=request.data) | ||||||
|         params.is_valid(raise_exception=True) |         params.is_valid(raise_exception=True) | ||||||
|         res: list[LogEvent] = self.sync_objects_task.delay( |         msg = self.sync_objects_task.send_with_options( | ||||||
|             params.validated_data["sync_object_model"], |             kwargs={ | ||||||
|             page=1, |                 "object_type": params.validated_data["sync_object_model"], | ||||||
|             provider_pk=provider.pk, |                 "page": 1, | ||||||
|             pk=params.validated_data["sync_object_id"], |                 "provider_pk": provider.pk, | ||||||
|             override_dry_run=params.validated_data["override_dry_run"], |                 "override_dry_run": params.validated_data["override_dry_run"], | ||||||
|         ).get() |                 "pk": params.validated_data["sync_object_id"], | ||||||
|         return Response(SyncObjectResultSerializer(instance={"messages": res}).data) |             }, | ||||||
|  |             rel_obj=provider, | ||||||
|  |         ) | ||||||
|  |         msg.get_result(block=True) | ||||||
|  |         task: Task = msg.options["task"] | ||||||
|  |         task.refresh_from_db() | ||||||
|  |         return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data) | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncConnectionCreateMixin: | class OutgoingSyncConnectionCreateMixin: | ||||||
|  | |||||||
| @ -1,12 +1,18 @@ | |||||||
| from typing import Any, Self | from typing import Any, Self | ||||||
|  |  | ||||||
| import pglock | import pglock | ||||||
|  | from django.core.paginator import Paginator | ||||||
| from django.db import connection, models | from django.db import connection, models | ||||||
| from django.db.models import Model, QuerySet, TextChoices | from django.db.models import Model, QuerySet, TextChoices | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import Actor | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
|  | from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS | ||||||
| from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient | from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  | from authentik.tasks.schedules.models import ScheduledModel | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncDeleteAction(TextChoices): | class OutgoingSyncDeleteAction(TextChoices): | ||||||
| @ -18,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices): | |||||||
|     SUSPEND = "suspend" |     SUSPEND = "suspend" | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutgoingSyncProvider(Model): | class OutgoingSyncProvider(ScheduledModel, Model): | ||||||
|     """Base abstract models for providers implementing outgoing sync""" |     """Base abstract models for providers implementing outgoing sync""" | ||||||
|  |  | ||||||
|     dry_run = models.BooleanField( |     dry_run = models.BooleanField( | ||||||
| @ -39,6 +45,19 @@ class OutgoingSyncProvider(Model): | |||||||
|     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: |     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def get_paginator[T: User | Group](self, type: type[T]) -> Paginator: | ||||||
|  |         return Paginator(self.get_object_qs(type), PAGE_SIZE) | ||||||
|  |  | ||||||
|  |     def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int: | ||||||
|  |         num_pages: int = self.get_paginator(type).num_pages | ||||||
|  |         return int(num_pages * PAGE_TIMEOUT_MS * 1.5) | ||||||
|  |  | ||||||
|  |     def get_sync_time_limit_ms(self) -> int: | ||||||
|  |         return int( | ||||||
|  |             (self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group)) | ||||||
|  |             * 1.5 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def sync_lock(self) -> pglock.advisory: |     def sync_lock(self) -> pglock.advisory: | ||||||
|         """Postgres lock for syncing to prevent multiple parallel syncs happening""" |         """Postgres lock for syncing to prevent multiple parallel syncs happening""" | ||||||
| @ -47,3 +66,22 @@ class OutgoingSyncProvider(Model): | |||||||
|             timeout=0, |             timeout=0, | ||||||
|             side_effect=pglock.Return, |             side_effect=pglock.Return, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def sync_actor(self) -> Actor: | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=self.sync_actor, | ||||||
|  |                 uid=self.pk, | ||||||
|  |                 args=(self.pk,), | ||||||
|  |                 options={ | ||||||
|  |                     "time_limit": self.get_sync_time_limit_ms(), | ||||||
|  |                 }, | ||||||
|  |                 send_on_save=True, | ||||||
|  |                 crontab=f"{fqdn_rand(self.pk)} */4 * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -1,12 +1,8 @@ | |||||||
| from collections.abc import Callable |  | ||||||
|  |  | ||||||
| from django.core.paginator import Paginator |  | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.query import Q |  | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete | from django.db.models.signals import m2m_changed, post_save, pre_delete | ||||||
|  | from dramatiq.actor import Actor | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT |  | ||||||
| from authentik.lib.sync.outgoing.base import Direction | from authentik.lib.sync.outgoing.base import Direction | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| @ -14,45 +10,30 @@ from authentik.lib.utils.reflection import class_to_path | |||||||
|  |  | ||||||
| def register_signals( | def register_signals( | ||||||
|     provider_type: type[OutgoingSyncProvider], |     provider_type: type[OutgoingSyncProvider], | ||||||
|     task_sync_single: Callable[[int], None], |     task_sync_direct_dispatch: Actor[[str, str | int, str], None], | ||||||
|     task_sync_direct: Callable[[int], None], |     task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None], | ||||||
|     task_sync_m2m: Callable[[int], None], |  | ||||||
| ): | ): | ||||||
|     """Register sync signals""" |     """Register sync signals""" | ||||||
|     uid = class_to_path(provider_type) |     uid = class_to_path(provider_type) | ||||||
|  |  | ||||||
|     def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_): |  | ||||||
|         """Trigger sync when Provider is saved""" |  | ||||||
|         users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE) |  | ||||||
|         groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE) |  | ||||||
|         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT |  | ||||||
|         time_limit = soft_time_limit * 1.5 |  | ||||||
|         task_sync_single.apply_async( |  | ||||||
|             (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False) |  | ||||||
|  |  | ||||||
|     def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): |     def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): | ||||||
|         """Post save handler""" |         """Post save handler""" | ||||||
|         if not provider_type.objects.filter( |         task_sync_direct_dispatch.send( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |             class_to_path(instance.__class__), | ||||||
|         ).exists(): |             instance.pk, | ||||||
|             return |             Direction.add.value, | ||||||
|         task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value) |         ) | ||||||
|  |  | ||||||
|     post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) |     post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) | ||||||
|     post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) |     post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) | ||||||
|  |  | ||||||
|     def model_pre_delete(sender: type[Model], instance: User | Group, **_): |     def model_pre_delete(sender: type[Model], instance: User | Group, **_): | ||||||
|         """Pre-delete handler""" |         """Pre-delete handler""" | ||||||
|         if not provider_type.objects.filter( |         task_sync_direct_dispatch.send( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |             class_to_path(instance.__class__), | ||||||
|         ).exists(): |             instance.pk, | ||||||
|             return |             Direction.remove.value, | ||||||
|         task_sync_direct.delay( |         ) | ||||||
|             class_to_path(instance.__class__), instance.pk, Direction.remove.value |  | ||||||
|         ).get(propagate=False) |  | ||||||
|  |  | ||||||
|     pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) |     pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) | ||||||
|     pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) |     pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) | ||||||
| @ -63,16 +44,6 @@ def register_signals( | |||||||
|         """Sync group membership""" |         """Sync group membership""" | ||||||
|         if action not in ["post_add", "post_remove"]: |         if action not in ["post_add", "post_remove"]: | ||||||
|             return |             return | ||||||
|         if not provider_type.objects.filter( |         task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse) | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |  | ||||||
|         ).exists(): |  | ||||||
|             return |  | ||||||
|         # reverse: instance is a Group, pk_set is a list of user pks |  | ||||||
|         # non-reverse: instance is a User, pk_set is a list of groups |  | ||||||
|         if reverse: |  | ||||||
|             task_sync_m2m.delay(str(instance.pk), action, list(pk_set)) |  | ||||||
|         else: |  | ||||||
|             for group_pk in pk_set: |  | ||||||
|                 task_sync_m2m.delay(group_pk, action, [instance.pk]) |  | ||||||
|  |  | ||||||
|     m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) |     m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) | ||||||
|  | |||||||
| @ -1,23 +1,17 @@ | |||||||
| from collections.abc import Callable |  | ||||||
| from dataclasses import asdict |  | ||||||
|  |  | ||||||
| from celery import group |  | ||||||
| from celery.exceptions import Retry |  | ||||||
| from celery.result import allow_join_result |  | ||||||
| from django.core.paginator import Paginator | from django.core.paginator import Paginator | ||||||
| from django.db.models import Model, QuerySet | from django.db.models import Model, QuerySet | ||||||
| from django.db.models.query import Q | from django.db.models.query import Q | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
| from django.utils.translation import gettext_lazy as _ | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
|  | from dramatiq.actor import Actor | ||||||
|  | from dramatiq.composition import group | ||||||
|  | from dramatiq.errors import Retry | ||||||
| from structlog.stdlib import BoundLogger, get_logger | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
| from authentik.core.expression.exceptions import SkipObjectException | from authentik.core.expression.exceptions import SkipObjectException | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.logs import LogEvent |  | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask |  | ||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT | from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS | ||||||
| from authentik.lib.sync.outgoing.base import Direction | from authentik.lib.sync.outgoing.base import Direction | ||||||
| from authentik.lib.sync.outgoing.exceptions import ( | from authentik.lib.sync.outgoing.exceptions import ( | ||||||
|     BadRequestSyncException, |     BadRequestSyncException, | ||||||
| @ -27,11 +21,12 @@ from authentik.lib.sync.outgoing.exceptions import ( | |||||||
| ) | ) | ||||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||||
| from authentik.lib.utils.reflection import class_to_path, path_to_class | from authentik.lib.utils.reflection import class_to_path, path_to_class | ||||||
|  | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
|  |  | ||||||
| class SyncTasks: | class SyncTasks: | ||||||
|     """Container for all sync 'tasks' (this class doesn't actually contain celery |     """Container for all sync 'tasks' (this class doesn't actually contain | ||||||
|     tasks due to celery's magic, however exposes a number of functions to be called from tasks)""" |     tasks due to dramatiq's magic, however exposes a number of functions to be called from tasks)""" | ||||||
|  |  | ||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
| @ -39,107 +34,104 @@ class SyncTasks: | |||||||
|         super().__init__() |         super().__init__() | ||||||
|         self._provider_model = provider_model |         self._provider_model = provider_model | ||||||
|  |  | ||||||
|     def sync_all(self, single_sync: Callable[[int], None]): |     def sync_paginator( | ||||||
|         for provider in self._provider_model.objects.filter( |  | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |  | ||||||
|         ): |  | ||||||
|             self.trigger_single_task(provider, single_sync) |  | ||||||
|  |  | ||||||
|     def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]): |  | ||||||
|         """Wrapper single sync task that correctly sets time limits based |  | ||||||
|         on the amount of objects that will be synced""" |  | ||||||
|         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) |  | ||||||
|         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) |  | ||||||
|         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT |  | ||||||
|         time_limit = soft_time_limit * 1.5 |  | ||||||
|         return sync_task.apply_async( |  | ||||||
|             (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def sync_single( |  | ||||||
|         self, |         self, | ||||||
|         task: SystemTask, |         current_task: Task, | ||||||
|         provider_pk: int, |         provider: OutgoingSyncProvider, | ||||||
|         sync_objects: Callable[[int, int], list[str]], |         sync_objects: Actor[[str, int, int, bool], None], | ||||||
|  |         paginator: Paginator, | ||||||
|  |         object_type: type[User | Group], | ||||||
|  |         **options, | ||||||
|     ): |     ): | ||||||
|  |         tasks = [] | ||||||
|  |         for page in paginator.page_range: | ||||||
|  |             page_sync = sync_objects.message_with_options( | ||||||
|  |                 args=(class_to_path(object_type), page, provider.pk), | ||||||
|  |                 time_limit=PAGE_TIMEOUT_MS, | ||||||
|  |                 # Assign tasks to the same schedule as the current one | ||||||
|  |                 rel_obj=current_task.rel_obj, | ||||||
|  |                 **options, | ||||||
|  |             ) | ||||||
|  |             tasks.append(page_sync) | ||||||
|  |         return tasks | ||||||
|  |  | ||||||
|  |     def sync( | ||||||
|  |         self, | ||||||
|  |         provider_pk: int, | ||||||
|  |         sync_objects: Actor[[str, int, int, bool], None], | ||||||
|  |     ): | ||||||
|  |         task: Task = CurrentTask.get_task() | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|             provider_pk=provider_pk, |             provider_pk=provider_pk, | ||||||
|         ) |         ) | ||||||
|         provider = self._provider_model.objects.filter( |         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), |             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||||
|             pk=provider_pk, |             pk=provider_pk, | ||||||
|         ).first() |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|  |             task.warning("No provider found. Is it assigned to an application?") | ||||||
|             return |             return | ||||||
|         task.set_uid(slugify(provider.name)) |         task.set_uid(slugify(provider.name)) | ||||||
|         messages = [] |         task.info("Starting full provider sync") | ||||||
|         messages.append(_("Starting full provider sync")) |  | ||||||
|         self.logger.debug("Starting provider sync") |         self.logger.debug("Starting provider sync") | ||||||
|         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) |         with provider.sync_lock as lock_acquired: | ||||||
|         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) |  | ||||||
|         with allow_join_result(), provider.sync_lock as lock_acquired: |  | ||||||
|             if not lock_acquired: |             if not lock_acquired: | ||||||
|  |                 task.info("Synchronization is already running. Skipping.") | ||||||
|                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) |                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) | ||||||
|                 return |                 return | ||||||
|             try: |             try: | ||||||
|                 messages.append(_("Syncing users")) |                 users_tasks = group( | ||||||
|                 user_results = ( |                     self.sync_paginator( | ||||||
|                     group( |                         current_task=task, | ||||||
|                         [ |                         provider=provider, | ||||||
|                             sync_objects.signature( |                         sync_objects=sync_objects, | ||||||
|                                 args=(class_to_path(User), page, provider_pk), |                         paginator=provider.get_paginator(User), | ||||||
|                                 time_limit=PAGE_TIMEOUT, |                         object_type=User, | ||||||
|                                 soft_time_limit=PAGE_TIMEOUT, |  | ||||||
|                     ) |                     ) | ||||||
|                             for page in users_paginator.page_range |  | ||||||
|                         ] |  | ||||||
|                 ) |                 ) | ||||||
|                     .apply_async() |                 group_tasks = group( | ||||||
|                     .get() |                     self.sync_paginator( | ||||||
|  |                         current_task=task, | ||||||
|  |                         provider=provider, | ||||||
|  |                         sync_objects=sync_objects, | ||||||
|  |                         paginator=provider.get_paginator(Group), | ||||||
|  |                         object_type=Group, | ||||||
|                     ) |                     ) | ||||||
|                 for result in user_results: |  | ||||||
|                     for msg in result: |  | ||||||
|                         messages.append(LogEvent(**msg)) |  | ||||||
|                 messages.append(_("Syncing groups")) |  | ||||||
|                 group_results = ( |  | ||||||
|                     group( |  | ||||||
|                         [ |  | ||||||
|                             sync_objects.signature( |  | ||||||
|                                 args=(class_to_path(Group), page, provider_pk), |  | ||||||
|                                 time_limit=PAGE_TIMEOUT, |  | ||||||
|                                 soft_time_limit=PAGE_TIMEOUT, |  | ||||||
|                 ) |                 ) | ||||||
|                             for page in groups_paginator.page_range |                 users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) | ||||||
|                         ] |                 group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) | ||||||
|                     ) |  | ||||||
|                     .apply_async() |  | ||||||
|                     .get() |  | ||||||
|                 ) |  | ||||||
|                 for result in group_results: |  | ||||||
|                     for msg in result: |  | ||||||
|                         messages.append(LogEvent(**msg)) |  | ||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 self.logger.warning("transient sync exception", exc=exc) |                 self.logger.warning("transient sync exception", exc=exc) | ||||||
|                 raise task.retry(exc=exc) from exc |                 task.warning("Sync encountered a transient exception. Retrying", exc=exc) | ||||||
|  |                 raise Retry() from exc | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|                 task.set_error(exc) |                 task.error(exc) | ||||||
|                 return |                 return | ||||||
|         task.set_status(TaskStatus.SUCCESSFUL, *messages) |  | ||||||
|  |  | ||||||
|     def sync_objects( |     def sync_objects( | ||||||
|         self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter |         self, | ||||||
|  |         object_type: str, | ||||||
|  |         page: int, | ||||||
|  |         provider_pk: int, | ||||||
|  |         override_dry_run=False, | ||||||
|  |         **filter, | ||||||
|     ): |     ): | ||||||
|  |         task: Task = CurrentTask.get_task() | ||||||
|         _object_type: type[Model] = path_to_class(object_type) |         _object_type: type[Model] = path_to_class(object_type) | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|             provider_pk=provider_pk, |             provider_pk=provider_pk, | ||||||
|             object_type=object_type, |             object_type=object_type, | ||||||
|         ) |         ) | ||||||
|         messages = [] |         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||||
|         provider = self._provider_model.objects.filter(pk=provider_pk).first() |             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||||
|  |             pk=provider_pk, | ||||||
|  |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|             return messages |             task.warning("No provider found. Is it assigned to an application?") | ||||||
|  |             return | ||||||
|  |         task.set_uid(slugify(provider.name)) | ||||||
|         # Override dry run mode if requested, however don't save the provider |         # Override dry run mode if requested, however don't save the provider | ||||||
|         # so that scheduled sync tasks still run in dry_run mode |         # so that scheduled sync tasks still run in dry_run mode | ||||||
|         if override_dry_run: |         if override_dry_run: | ||||||
| @ -147,25 +139,13 @@ class SyncTasks: | |||||||
|         try: |         try: | ||||||
|             client = provider.client_for_model(_object_type) |             client = provider.client_for_model(_object_type) | ||||||
|         except TransientSyncException: |         except TransientSyncException: | ||||||
|             return messages |             return | ||||||
|         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) |         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) | ||||||
|         if client.can_discover: |         if client.can_discover: | ||||||
|             self.logger.debug("starting discover") |             self.logger.debug("starting discover") | ||||||
|             client.discover() |             client.discover() | ||||||
|         self.logger.debug("starting sync for page", page=page) |         self.logger.debug("starting sync for page", page=page) | ||||||
|         messages.append( |         task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}") | ||||||
|             asdict( |  | ||||||
|                 LogEvent( |  | ||||||
|                     _( |  | ||||||
|                         "Syncing page {page} of {object_type}".format( |  | ||||||
|                             page=page, object_type=_object_type._meta.verbose_name_plural |  | ||||||
|                         ) |  | ||||||
|                     ), |  | ||||||
|                     log_level="info", |  | ||||||
|                     logger=f"{provider._meta.verbose_name}@{object_type}", |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         for obj in paginator.page(page).object_list: |         for obj in paginator.page(page).object_list: | ||||||
|             obj: Model |             obj: Model | ||||||
|             try: |             try: | ||||||
| @ -174,89 +154,58 @@ class SyncTasks: | |||||||
|                 self.logger.debug("skipping object due to SkipObject", obj=obj) |                 self.logger.debug("skipping object due to SkipObject", obj=obj) | ||||||
|                 continue |                 continue | ||||||
|             except DryRunRejected as exc: |             except DryRunRejected as exc: | ||||||
|                 messages.append( |                 task.info( | ||||||
|                     asdict( |                     "Dropping mutating request due to dry run", | ||||||
|                         LogEvent( |                     obj=sanitize_item(obj), | ||||||
|                             _("Dropping mutating request due to dry run"), |                     method=exc.method, | ||||||
|                             log_level="info", |                     url=exc.url, | ||||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", |                     body=exc.body, | ||||||
|                             attributes={ |  | ||||||
|                                 "obj": sanitize_item(obj), |  | ||||||
|                                 "method": exc.method, |  | ||||||
|                                 "url": exc.url, |  | ||||||
|                                 "body": exc.body, |  | ||||||
|                             }, |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|             except BadRequestSyncException as exc: |             except BadRequestSyncException as exc: | ||||||
|                 self.logger.warning("failed to sync object", exc=exc, obj=obj) |                 self.logger.warning("failed to sync object", exc=exc, obj=obj) | ||||||
|                 messages.append( |                 task.warning( | ||||||
|                     asdict( |                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", | ||||||
|                         LogEvent( |                     arguments=exc.args[1:], | ||||||
|                             _( |                     obj=sanitize_item(obj), | ||||||
|                                 ( |  | ||||||
|                                     "Failed to sync {object_type} {object_name} " |  | ||||||
|                                     "due to error: {error}" |  | ||||||
|                                 ).format_map( |  | ||||||
|                                     { |  | ||||||
|                                         "object_type": obj._meta.verbose_name, |  | ||||||
|                                         "object_name": str(obj), |  | ||||||
|                                         "error": str(exc), |  | ||||||
|                                     } |  | ||||||
|                                 ) |  | ||||||
|                             ), |  | ||||||
|                             log_level="warning", |  | ||||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", |  | ||||||
|                             attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)}, |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|             except TransientSyncException as exc: |             except TransientSyncException as exc: | ||||||
|                 self.logger.warning("failed to sync object", exc=exc, user=obj) |                 self.logger.warning("failed to sync object", exc=exc, user=obj) | ||||||
|                 messages.append( |                 task.warning( | ||||||
|                     asdict( |                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " | ||||||
|                         LogEvent( |                     "transient error: {str(exc)}", | ||||||
|                             _( |                     obj=sanitize_item(obj), | ||||||
|                                 ( |  | ||||||
|                                     "Failed to sync {object_type} {object_name} " |  | ||||||
|                                     "due to transient error: {error}" |  | ||||||
|                                 ).format_map( |  | ||||||
|                                     { |  | ||||||
|                                         "object_type": obj._meta.verbose_name, |  | ||||||
|                                         "object_name": str(obj), |  | ||||||
|                                         "error": str(exc), |  | ||||||
|                                     } |  | ||||||
|                                 ) |  | ||||||
|                             ), |  | ||||||
|                             log_level="warning", |  | ||||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", |  | ||||||
|                             attributes={"obj": sanitize_item(obj)}, |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|             except StopSync as exc: |             except StopSync as exc: | ||||||
|                 self.logger.warning("Stopping sync", exc=exc) |                 self.logger.warning("Stopping sync", exc=exc) | ||||||
|                 messages.append( |                 task.warning( | ||||||
|                     asdict( |                     f"Stopping sync due to error: {exc.detail()}", | ||||||
|                         LogEvent( |                     obj=sanitize_item(obj), | ||||||
|                             _( |  | ||||||
|                                 "Stopping sync due to error: {error}".format_map( |  | ||||||
|                                     { |  | ||||||
|                                         "error": exc.detail(), |  | ||||||
|                                     } |  | ||||||
|                                 ) |  | ||||||
|                             ), |  | ||||||
|                             log_level="warning", |  | ||||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", |  | ||||||
|                             attributes={"obj": sanitize_item(obj)}, |  | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                 ) |                 ) | ||||||
|                 break |                 break | ||||||
|         return messages |  | ||||||
|  |  | ||||||
|     def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): |     def sync_signal_direct_dispatch( | ||||||
|  |         self, | ||||||
|  |         task_sync_signal_direct: Actor[[str, str | int, int, str], None], | ||||||
|  |         model: str, | ||||||
|  |         pk: str | int, | ||||||
|  |         raw_op: str, | ||||||
|  |     ): | ||||||
|  |         for provider in self._provider_model.objects.filter( | ||||||
|  |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|  |         ): | ||||||
|  |             task_sync_signal_direct.send_with_options( | ||||||
|  |                 args=(model, pk, provider.pk, raw_op), | ||||||
|  |                 rel_obj=provider, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def sync_signal_direct( | ||||||
|  |         self, | ||||||
|  |         model: str, | ||||||
|  |         pk: str | int, | ||||||
|  |         provider_pk: int, | ||||||
|  |         raw_op: str, | ||||||
|  |     ): | ||||||
|  |         task: Task = CurrentTask.get_task() | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|         ) |         ) | ||||||
| @ -264,20 +213,25 @@ class SyncTasks: | |||||||
|         instance = model_class.objects.filter(pk=pk).first() |         instance = model_class.objects.filter(pk=pk).first() | ||||||
|         if not instance: |         if not instance: | ||||||
|             return |             return | ||||||
|  |         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||||
|  |             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||||
|  |             pk=provider_pk, | ||||||
|  |         ).first() | ||||||
|  |         if not provider: | ||||||
|  |             task.warning("No provider found. Is it assigned to an application?") | ||||||
|  |             return | ||||||
|  |         task.set_uid(slugify(provider.name)) | ||||||
|         operation = Direction(raw_op) |         operation = Direction(raw_op) | ||||||
|         for provider in self._provider_model.objects.filter( |  | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |  | ||||||
|         ): |  | ||||||
|         client = provider.client_for_model(instance.__class__) |         client = provider.client_for_model(instance.__class__) | ||||||
|         # Check if the object is allowed within the provider's restrictions |         # Check if the object is allowed within the provider's restrictions | ||||||
|         queryset = provider.get_object_qs(instance.__class__) |         queryset = provider.get_object_qs(instance.__class__) | ||||||
|         if not queryset: |         if not queryset: | ||||||
|                 continue |             return | ||||||
|  |  | ||||||
|         # The queryset we get from the provider must include the instance we've got given |         # The queryset we get from the provider must include the instance we've got given | ||||||
|         # otherwise ignore this provider |         # otherwise ignore this provider | ||||||
|         if not queryset.filter(pk=instance.pk).exists(): |         if not queryset.filter(pk=instance.pk).exists(): | ||||||
|                 continue |             return | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             if operation == Direction.add: |             if operation == Direction.add: | ||||||
| @ -287,28 +241,66 @@ class SyncTasks: | |||||||
|         except TransientSyncException as exc: |         except TransientSyncException as exc: | ||||||
|             raise Retry() from exc |             raise Retry() from exc | ||||||
|         except SkipObjectException: |         except SkipObjectException: | ||||||
|                 continue |             return | ||||||
|         except DryRunRejected as exc: |         except DryRunRejected as exc: | ||||||
|             self.logger.info("Rejected dry-run event", exc=exc) |             self.logger.info("Rejected dry-run event", exc=exc) | ||||||
|         except StopSync as exc: |         except StopSync as exc: | ||||||
|             self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) |             self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||||
|  |  | ||||||
|     def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): |     def sync_signal_m2m_dispatch( | ||||||
|  |         self, | ||||||
|  |         task_sync_signal_m2m: Actor[[str, int, str, list[int]], None], | ||||||
|  |         instance_pk: str, | ||||||
|  |         action: str, | ||||||
|  |         pk_set: list[int], | ||||||
|  |         reverse: bool, | ||||||
|  |     ): | ||||||
|  |         for provider in self._provider_model.objects.filter( | ||||||
|  |             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||||
|  |         ): | ||||||
|  |             # reverse: instance is a Group, pk_set is a list of user pks | ||||||
|  |             # non-reverse: instance is a User, pk_set is a list of groups | ||||||
|  |             if reverse: | ||||||
|  |                 task_sync_signal_m2m.send_with_options( | ||||||
|  |                     args=(instance_pk, provider.pk, action, list(pk_set)), | ||||||
|  |                     rel_obj=provider, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 for pk in pk_set: | ||||||
|  |                     task_sync_signal_m2m.send_with_options( | ||||||
|  |                         args=(pk, provider.pk, action, [instance_pk]), | ||||||
|  |                         rel_obj=provider, | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |     def sync_signal_m2m( | ||||||
|  |         self, | ||||||
|  |         group_pk: str, | ||||||
|  |         provider_pk: int, | ||||||
|  |         action: str, | ||||||
|  |         pk_set: list[int], | ||||||
|  |     ): | ||||||
|  |         task: Task = CurrentTask.get_task() | ||||||
|         self.logger = get_logger().bind( |         self.logger = get_logger().bind( | ||||||
|             provider_type=class_to_path(self._provider_model), |             provider_type=class_to_path(self._provider_model), | ||||||
|         ) |         ) | ||||||
|         group = Group.objects.filter(pk=group_pk).first() |         group = Group.objects.filter(pk=group_pk).first() | ||||||
|         if not group: |         if not group: | ||||||
|             return |             return | ||||||
|         for provider in self._provider_model.objects.filter( |         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) |             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||||
|         ): |             pk=provider_pk, | ||||||
|  |         ).first() | ||||||
|  |         if not provider: | ||||||
|  |             task.warning("No provider found. Is it assigned to an application?") | ||||||
|  |             return | ||||||
|  |         task.set_uid(slugify(provider.name)) | ||||||
|  |  | ||||||
|         # Check if the object is allowed within the provider's restrictions |         # Check if the object is allowed within the provider's restrictions | ||||||
|         queryset: QuerySet = provider.get_object_qs(Group) |         queryset: QuerySet = provider.get_object_qs(Group) | ||||||
|         # The queryset we get from the provider must include the instance we've got given |         # The queryset we get from the provider must include the instance we've got given | ||||||
|         # otherwise ignore this provider |         # otherwise ignore this provider | ||||||
|         if not queryset.filter(pk=group_pk).exists(): |         if not queryset.filter(pk=group_pk).exists(): | ||||||
|                 continue |             return | ||||||
|  |  | ||||||
|         client = provider.client_for_model(Group) |         client = provider.client_for_model(Group) | ||||||
|         try: |         try: | ||||||
| @ -321,7 +313,7 @@ class SyncTasks: | |||||||
|         except TransientSyncException as exc: |         except TransientSyncException as exc: | ||||||
|             raise Retry() from exc |             raise Retry() from exc | ||||||
|         except SkipObjectException: |         except SkipObjectException: | ||||||
|                 continue |             return | ||||||
|         except DryRunRejected as exc: |         except DryRunRejected as exc: | ||||||
|             self.logger.info("Rejected dry-run event", exc=exc) |             self.logger.info("Rejected dry-run event", exc=exc) | ||||||
|         except StopSync as exc: |         except StopSync as exc: | ||||||
|  | |||||||
| @ -5,6 +5,8 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -60,3 +62,27 @@ class AuthentikOutpostConfig(ManagedAppConfig): | |||||||
|                 outpost.save() |                 outpost.save() | ||||||
|         else: |         else: | ||||||
|             Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() |             Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.outposts.tasks import outpost_token_ensurer | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=outpost_token_ensurer, | ||||||
|  |                 crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *", | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def global_schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.outposts.tasks import outpost_connection_discovery | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=outpost_connection_discovery, | ||||||
|  |                 crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *", | ||||||
|  |                 send_on_startup=True, | ||||||
|  |                 paused=not CONFIG.get_bool("outposts.discover"), | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  | |||||||
| @ -101,7 +101,13 @@ class KubernetesController(BaseController): | |||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: |                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||||
|                     all_logs += [f"{reconcile_key.title()}: Disabled"] |                     all_logs.append( | ||||||
|  |                         LogEvent( | ||||||
|  |                             log_level="info", | ||||||
|  |                             event=f"{reconcile_key.title()}: Disabled", | ||||||
|  |                             logger=str(type(self)), | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                     continue |                     continue | ||||||
|                 with capture_logs() as logs: |                 with capture_logs() as logs: | ||||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) |                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||||
| @ -134,7 +140,13 @@ class KubernetesController(BaseController): | |||||||
|             all_logs = [] |             all_logs = [] | ||||||
|             for reconcile_key in self.reconcile_order: |             for reconcile_key in self.reconcile_order: | ||||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: |                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||||
|                     all_logs += [f"{reconcile_key.title()}: Disabled"] |                     all_logs.append( | ||||||
|  |                         LogEvent( | ||||||
|  |                             log_level="info", | ||||||
|  |                             event=f"{reconcile_key.title()}: Disabled", | ||||||
|  |                             logger=str(type(self)), | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|                     continue |                     continue | ||||||
|                 with capture_logs() as logs: |                 with capture_logs() as logs: | ||||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) |                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||||
|  | |||||||
| @ -36,7 +36,10 @@ from authentik.lib.config import CONFIG | |||||||
| from authentik.lib.models import InheritanceForeignKey, SerializerModel | from authentik.lib.models import InheritanceForeignKey, SerializerModel | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  | from authentik.lib.utils.time import fqdn_rand | ||||||
| from authentik.outposts.controllers.k8s.utils import get_namespace | from authentik.outposts.controllers.k8s.utils import get_namespace | ||||||
|  | from authentik.tasks.schedules.lib import ScheduleSpec | ||||||
|  | from authentik.tasks.schedules.models import ScheduledModel | ||||||
|  |  | ||||||
| OUR_VERSION = parse(__version__) | OUR_VERSION = parse(__version__) | ||||||
| OUTPOST_HELLO_INTERVAL = 10 | OUTPOST_HELLO_INTERVAL = 10 | ||||||
| @ -115,7 +118,7 @@ class OutpostServiceConnectionState: | |||||||
|     healthy: bool |     healthy: bool | ||||||
|  |  | ||||||
|  |  | ||||||
| class OutpostServiceConnection(models.Model): | class OutpostServiceConnection(ScheduledModel, models.Model): | ||||||
|     """Connection details for an Outpost Controller, like Docker or Kubernetes""" |     """Connection details for an Outpost Controller, like Docker or Kubernetes""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) |     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||||
| @ -145,11 +148,11 @@ class OutpostServiceConnection(models.Model): | |||||||
|     @property |     @property | ||||||
|     def state(self) -> OutpostServiceConnectionState: |     def state(self) -> OutpostServiceConnectionState: | ||||||
|         """Get state of service connection""" |         """Get state of service connection""" | ||||||
|         from authentik.outposts.tasks import outpost_service_connection_state |         from authentik.outposts.tasks import outpost_service_connection_monitor | ||||||
|  |  | ||||||
|         state = cache.get(self.state_key, None) |         state = cache.get(self.state_key, None) | ||||||
|         if not state: |         if not state: | ||||||
|             outpost_service_connection_state.delay(self.pk) |             outpost_service_connection_monitor.send_with_options(args=(self.pk), rel_obj=self) | ||||||
|             return OutpostServiceConnectionState("", False) |             return OutpostServiceConnectionState("", False) | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
| @ -160,6 +163,20 @@ class OutpostServiceConnection(models.Model): | |||||||
|         # since the response doesn't use the correct inheritance |         # since the response doesn't use the correct inheritance | ||||||
|         return "" |         return "" | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.outposts.tasks import outpost_service_connection_monitor | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=outpost_service_connection_monitor, | ||||||
|  |                 uid=self.pk, | ||||||
|  |                 args=(self.pk,), | ||||||
|  |                 crontab="3-59/15 * * * *", | ||||||
|  |                 send_on_save=True, | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
| class DockerServiceConnection(SerializerModel, OutpostServiceConnection): | class DockerServiceConnection(SerializerModel, OutpostServiceConnection): | ||||||
|     """Service Connection to a Docker endpoint""" |     """Service Connection to a Docker endpoint""" | ||||||
| @ -244,7 +261,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): | |||||||
|         return "ak-service-connection-kubernetes-form" |         return "ak-service-connection-kubernetes-form" | ||||||
|  |  | ||||||
|  |  | ||||||
| class Outpost(SerializerModel, ManagedModel): | class Outpost(ScheduledModel, SerializerModel, ManagedModel): | ||||||
|     """Outpost instance which manages a service user and token""" |     """Outpost instance which manages a service user and token""" | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) |     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||||
| @ -298,6 +315,21 @@ class Outpost(SerializerModel, ManagedModel): | |||||||
|         """Username for service user""" |         """Username for service user""" | ||||||
|         return f"ak-outpost-{self.uuid.hex}" |         return f"ak-outpost-{self.uuid.hex}" | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def schedule_specs(self) -> list[ScheduleSpec]: | ||||||
|  |         from authentik.outposts.tasks import outpost_controller | ||||||
|  |  | ||||||
|  |         return [ | ||||||
|  |             ScheduleSpec( | ||||||
|  |                 actor=outpost_controller, | ||||||
|  |                 uid=self.pk, | ||||||
|  |                 args=(self.pk,), | ||||||
|  |                 kwargs={"action": "up", "from_cache": False}, | ||||||
|  |                 crontab=f"{fqdn_rand('outpost_controller')} */4 * * *", | ||||||
|  |                 send_on_save=True, | ||||||
|  |             ), | ||||||
|  |         ] | ||||||
|  |  | ||||||
|     def build_user_permissions(self, user: User): |     def build_user_permissions(self, user: User): | ||||||
|         """Create per-object and global permissions for outpost service-account""" |         """Create per-object and global permissions for outpost service-account""" | ||||||
|         # To ensure the user only has the correct permissions, we delete all of them and re-add |         # To ensure the user only has the correct permissions, we delete all of them and re-add | ||||||
|  | |||||||
| @ -1,28 +0,0 @@ | |||||||
| """Outposts Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "outposts_controller": { |  | ||||||
|         "task": "authentik.outposts.tasks.outpost_controller_all", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
|     "outposts_service_connection_check": { |  | ||||||
|         "task": "authentik.outposts.tasks.outpost_service_connection_monitor", |  | ||||||
|         "schedule": crontab(minute="3-59/15"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
|     "outpost_token_ensurer": { |  | ||||||
|         "task": "authentik.outposts.tasks.outpost_token_ensurer", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
|     "outpost_connection_discovery": { |  | ||||||
|         "task": "authentik.outposts.tasks.outpost_connection_discovery", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik outpost signals""" | """authentik outpost signals""" | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import Model |  | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -9,27 +8,19 @@ from structlog.stdlib import get_logger | |||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import AuthenticatedSession, Provider | from authentik.core.models import AuthenticatedSession, Provider | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection | ||||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection |  | ||||||
| from authentik.outposts.tasks import ( | from authentik.outposts.tasks import ( | ||||||
|     CACHE_KEY_OUTPOST_DOWN, |     CACHE_KEY_OUTPOST_DOWN, | ||||||
|     outpost_controller, |     outpost_controller, | ||||||
|     outpost_post_save, |     outpost_send_update, | ||||||
|     outpost_session_end, |     outpost_session_end, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| UPDATE_TRIGGERING_MODELS = ( |  | ||||||
|     Outpost, |  | ||||||
|     OutpostServiceConnection, |  | ||||||
|     Provider, |  | ||||||
|     CertificateKeyPair, |  | ||||||
|     Brand, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=Outpost) | @receiver(pre_save, sender=Outpost) | ||||||
| def pre_save_outpost(sender, instance: Outpost, **_): | def outpost_pre_save(sender, instance: Outpost, **_): | ||||||
|     """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, |     """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, | ||||||
|     we call down and then wait for the up after save""" |     we call down and then wait for the up after save""" | ||||||
|     old_instances = Outpost.objects.filter(pk=instance.pk) |     old_instances = Outpost.objects.filter(pk=instance.pk) | ||||||
| @ -44,43 +35,89 @@ def pre_save_outpost(sender, instance: Outpost, **_): | |||||||
|     if bool(dirty): |     if bool(dirty): | ||||||
|         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) |         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) | ||||||
|         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) |         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) | ||||||
|         outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) |         outpost_controller.send_with_options( | ||||||
|  |             args=(instance.pk.hex,), | ||||||
|  |             kwargs={"action": "down", "from_cache": True}, | ||||||
|  |             rel_obj=instance, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(m2m_changed, sender=Outpost.providers.through) | @receiver(m2m_changed, sender=Outpost.providers.through) | ||||||
| def m2m_changed_update(sender, instance: Model, action: str, **_): | def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_): | ||||||
|     """Update outpost on m2m change, when providers are added or removed""" |     """Update outpost on m2m change, when providers are added or removed""" | ||||||
|     if action in ["post_add", "post_remove", "post_clear"]: |     if action not in ["post_add", "post_remove", "post_clear"]: | ||||||
|         outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) |         return | ||||||
|  |     if isinstance(instance, Outpost): | ||||||
|  |         outpost_controller.send_with_options( | ||||||
|  |             args=(instance.pk,), | ||||||
|  |             rel_obj=instance.service_connection, | ||||||
|  |         ) | ||||||
|  |         outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||||
|  |     elif isinstance(instance, OutpostModel): | ||||||
|  |         for outpost in instance.outpost_set.all(): | ||||||
|  |             outpost_controller.send_with_options( | ||||||
|  |                 args=(instance.pk,), | ||||||
|  |                 rel_obj=instance.service_connection, | ||||||
|  |             ) | ||||||
|  |             outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save) | @receiver(post_save, sender=Outpost) | ||||||
| def post_save_update(sender, instance: Model, created: bool, **_): | def outpost_post_save(sender, instance: Outpost, created: bool, **_): | ||||||
|     """If an Outpost is saved, Ensure that token is created/updated |     if created: | ||||||
|  |  | ||||||
|     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, |  | ||||||
|     we send a message down the relevant OutpostModels WS connection to trigger an update""" |  | ||||||
|     if instance.__module__ == "django.db.migrations.recorder": |  | ||||||
|         return |  | ||||||
|     if instance.__module__ == "__fake__": |  | ||||||
|         return |  | ||||||
|     if not isinstance(instance, UPDATE_TRIGGERING_MODELS): |  | ||||||
|         return |  | ||||||
|     if isinstance(instance, Outpost) and created: |  | ||||||
|         LOGGER.info("New outpost saved, ensuring initial token and user are created") |         LOGGER.info("New outpost saved, ensuring initial token and user are created") | ||||||
|         _ = instance.token |         _ = instance.token | ||||||
|     outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) |     outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection) | ||||||
|  |     outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_): | ||||||
|  |     for outpost in instance.outpost_set.all(): | ||||||
|  |         outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False) | ||||||
|  | for subclass in OutpostModel.__subclasses__(): | ||||||
|  |     post_save.connect(outpost_related_post_save, sender=subclass, weak=False) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_): | ||||||
|  |     for field in instance._meta.get_fields(): | ||||||
|  |         # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms) | ||||||
|  |         # are used, and if it has a value | ||||||
|  |         if not hasattr(field, "related_model"): | ||||||
|  |             continue | ||||||
|  |         if not field.related_model: | ||||||
|  |             continue | ||||||
|  |         if not issubclass(field.related_model, OutpostModel): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         field_name = f"{field.name}_set" | ||||||
|  |         if not hasattr(instance, field_name): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         LOGGER.debug("triggering outpost update from field", field=field.name) | ||||||
|  |         # Because the Outpost Model has an M2M to Provider, | ||||||
|  |         # we have to iterate over the entire QS | ||||||
|  |         for reverse in getattr(instance, field_name).all(): | ||||||
|  |             if isinstance(reverse, OutpostModel): | ||||||
|  |                 for outpost in reverse.outpost_set.all(): | ||||||
|  |                     outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False) | ||||||
|  | post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=Outpost) | @receiver(pre_delete, sender=Outpost) | ||||||
| def pre_delete_cleanup(sender, instance: Outpost, **_): | def outpost_pre_delete_cleanup(sender, instance: Outpost, **_): | ||||||
|     """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" |     """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" | ||||||
|     instance.user.delete() |     instance.user.delete() | ||||||
|     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) |     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) | ||||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) |     outpost_controller.send(instance.pk.hex, action="down", from_cache=True) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|     """Catch logout by expiring sessions being deleted""" |     """Catch logout by expiring sessions being deleted""" | ||||||
|     outpost_session_end.delay(instance.session.session_key) |     outpost_session_end.send(instance.session.session_key) | ||||||
|  | |||||||
| @ -10,19 +10,17 @@ from urllib.parse import urlparse | |||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db import DatabaseError, InternalError, ProgrammingError |  | ||||||
| from django.db.models.base import Model |  | ||||||
| from django.utils.text import slugify | from django.utils.text import slugify | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from django_dramatiq_postgres.middleware import CurrentTask | ||||||
| from docker.constants import DEFAULT_UNIX_SOCKET | from docker.constants import DEFAULT_UNIX_SOCKET | ||||||
|  | from dramatiq.actor import actor | ||||||
| from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | ||||||
| from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from yaml import safe_load | from yaml import safe_load | ||||||
|  |  | ||||||
| from authentik.events.models import TaskStatus |  | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.reflection import path_to_class |  | ||||||
| from authentik.outposts.consumer import OUTPOST_GROUP | from authentik.outposts.consumer import OUTPOST_GROUP | ||||||
| from authentik.outposts.controllers.base import BaseController, ControllerException | from authentik.outposts.controllers.base import BaseController, ControllerException | ||||||
| from authentik.outposts.controllers.docker import DockerClient | from authentik.outposts.controllers.docker import DockerClient | ||||||
| @ -31,7 +29,6 @@ from authentik.outposts.models import ( | |||||||
|     DockerServiceConnection, |     DockerServiceConnection, | ||||||
|     KubernetesServiceConnection, |     KubernetesServiceConnection, | ||||||
|     Outpost, |     Outpost, | ||||||
|     OutpostModel, |  | ||||||
|     OutpostServiceConnection, |     OutpostServiceConnection, | ||||||
|     OutpostType, |     OutpostType, | ||||||
|     ServiceConnectionInvalid, |     ServiceConnectionInvalid, | ||||||
| @ -44,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController | |||||||
| from authentik.providers.rac.controllers.kubernetes import RACKubernetesController | from authentik.providers.rac.controllers.kubernetes import RACKubernetesController | ||||||
| from authentik.providers.radius.controllers.docker import RadiusDockerController | from authentik.providers.radius.controllers.docker import RadiusDockerController | ||||||
| from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.tasks.models import Task | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | ||||||
| @ -83,8 +80,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: | |||||||
|     return None |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Update cached state of service connection.")) | ||||||
| def outpost_service_connection_state(connection_pk: Any): | def outpost_service_connection_monitor(connection_pk: Any): | ||||||
|     """Update cached state of a service connection""" |     """Update cached state of a service connection""" | ||||||
|     connection: OutpostServiceConnection = ( |     connection: OutpostServiceConnection = ( | ||||||
|         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() |         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() | ||||||
| @ -108,37 +105,11 @@ def outpost_service_connection_state(connection_pk: Any): | |||||||
|     cache.set(connection.state_key, state, timeout=None) |     cache.set(connection.state_key, state, timeout=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Create/update/monitor/delete the deployment of an Outpost.")) | ||||||
|     bind=True, | def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): | ||||||
|     base=SystemTask, |  | ||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), |  | ||||||
| ) |  | ||||||
| @prefill_task |  | ||||||
| def outpost_service_connection_monitor(self: SystemTask): |  | ||||||
|     """Regularly check the state of Outpost Service Connections""" |  | ||||||
|     connections = OutpostServiceConnection.objects.all() |  | ||||||
|     for connection in connections.iterator(): |  | ||||||
|         outpost_service_connection_state.delay(connection.pk) |  | ||||||
|     self.set_status( |  | ||||||
|         TaskStatus.SUCCESSFUL, |  | ||||||
|         f"Successfully updated {len(connections)} connections.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( |  | ||||||
|     throws=(DatabaseError, ProgrammingError, InternalError), |  | ||||||
| ) |  | ||||||
| def outpost_controller_all(): |  | ||||||
|     """Launch Controller for all Outposts which support it""" |  | ||||||
|     for outpost in Outpost.objects.exclude(service_connection=None): |  | ||||||
|         outpost_controller.delay(outpost.pk.hex, "up", from_cache=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) |  | ||||||
| def outpost_controller( |  | ||||||
|     self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False |  | ||||||
| ): |  | ||||||
|     """Create/update/monitor/delete the deployment of an Outpost""" |     """Create/update/monitor/delete the deployment of an Outpost""" | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|  |     self.set_uid(outpost_pk) | ||||||
|     logs = [] |     logs = [] | ||||||
|     if from_cache: |     if from_cache: | ||||||
|         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) |         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||||
| @ -159,125 +130,65 @@ def outpost_controller( | |||||||
|             logs = getattr(controller, f"{action}_with_logs")() |             logs = getattr(controller, f"{action}_with_logs")() | ||||||
|             LOGGER.debug("-----------------Outpost Controller logs end-------------------") |             LOGGER.debug("-----------------Outpost Controller logs end-------------------") | ||||||
|     except (ControllerException, ServiceConnectionInvalid) as exc: |     except (ControllerException, ServiceConnectionInvalid) as exc: | ||||||
|         self.set_error(exc) |         self.error(exc) | ||||||
|     else: |     else: | ||||||
|         if from_cache: |         if from_cache: | ||||||
|             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) |             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL, *logs) |         self.logs(logs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens.")) | ||||||
| @prefill_task | def outpost_token_ensurer(): | ||||||
| def outpost_token_ensurer(self: SystemTask): |     """ | ||||||
|     """Periodically ensure that all Outposts have valid Service Accounts |     Periodically ensure that all Outposts have valid Service Accounts and Tokens | ||||||
|     and Tokens""" |     """ | ||||||
|  |     self: Task = CurrentTask.get_task() | ||||||
|     all_outposts = Outpost.objects.all() |     all_outposts = Outpost.objects.all() | ||||||
|     for outpost in all_outposts: |     for outpost in all_outposts: | ||||||
|         _ = outpost.token |         _ = outpost.token | ||||||
|         outpost.build_user_permissions(outpost.user) |         outpost.build_user_permissions(outpost.user) | ||||||
|     self.set_status( |     self.info(f"Successfully checked {len(all_outposts)} Outposts.") | ||||||
|         TaskStatus.SUCCESSFUL, |  | ||||||
|         f"Successfully checked {len(all_outposts)} Outposts.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Send update to outpost")) | ||||||
| def outpost_post_save(model_class: str, model_pk: Any): | def outpost_send_update(pk: Any): | ||||||
|     """If an Outpost is saved, Ensure that token is created/updated |     """Update outpost instance""" | ||||||
|  |     outpost = Outpost.objects.filter(pk=pk).first() | ||||||
|     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, |     if not outpost: | ||||||
|     we send a message down the relevant OutpostModels WS connection to trigger an update""" |  | ||||||
|     model: Model = path_to_class(model_class) |  | ||||||
|     try: |  | ||||||
|         instance = model.objects.get(pk=model_pk) |  | ||||||
|     except model.DoesNotExist: |  | ||||||
|         LOGGER.warning("Model does not exist", model=model, pk=model_pk) |  | ||||||
|         return |         return | ||||||
|  |  | ||||||
|     if isinstance(instance, Outpost): |  | ||||||
|         LOGGER.debug("Trigger reconcile for outpost", instance=instance) |  | ||||||
|         outpost_controller.delay(str(instance.pk)) |  | ||||||
|  |  | ||||||
|     if isinstance(instance, OutpostModel | Outpost): |  | ||||||
|         LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) |  | ||||||
|         outpost_send_update(instance) |  | ||||||
|  |  | ||||||
|     if isinstance(instance, OutpostServiceConnection): |  | ||||||
|         LOGGER.debug("triggering ServiceConnection state update", instance=instance) |  | ||||||
|         outpost_service_connection_state.delay(str(instance.pk)) |  | ||||||
|  |  | ||||||
|     for field in instance._meta.get_fields(): |  | ||||||
|         # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms) |  | ||||||
|         # are used, and if it has a value |  | ||||||
|         if not hasattr(field, "related_model"): |  | ||||||
|             continue |  | ||||||
|         if not field.related_model: |  | ||||||
|             continue |  | ||||||
|         if not issubclass(field.related_model, OutpostModel): |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         field_name = f"{field.name}_set" |  | ||||||
|         if not hasattr(instance, field_name): |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         LOGGER.debug("triggering outpost update from field", field=field.name) |  | ||||||
|         # Because the Outpost Model has an M2M to Provider, |  | ||||||
|         # we have to iterate over the entire QS |  | ||||||
|         for reverse in getattr(instance, field_name).all(): |  | ||||||
|             outpost_send_update(reverse) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def outpost_send_update(model_instance: Model): |  | ||||||
|     """Send outpost update to all registered outposts, regardless to which authentik |  | ||||||
|     instance they are connected""" |  | ||||||
|     channel_layer = get_channel_layer() |  | ||||||
|     if isinstance(model_instance, OutpostModel): |  | ||||||
|         for outpost in model_instance.outpost_set.all(): |  | ||||||
|             _outpost_single_update(outpost, channel_layer) |  | ||||||
|     elif isinstance(model_instance, Outpost): |  | ||||||
|         _outpost_single_update(model_instance, channel_layer) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _outpost_single_update(outpost: Outpost, layer=None): |  | ||||||
|     """Update outpost instances connected to a single outpost""" |  | ||||||
|     # Ensure token again, because this function is called when anything related to an |     # Ensure token again, because this function is called when anything related to an | ||||||
|     # OutpostModel is saved, so we can be sure permissions are right |     # OutpostModel is saved, so we can be sure permissions are right | ||||||
|     _ = outpost.token |     _ = outpost.token | ||||||
|     outpost.build_user_permissions(outpost.user) |     outpost.build_user_permissions(outpost.user) | ||||||
|     if not layer:  # pragma: no cover |  | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|     group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} |     group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||||
|     LOGGER.debug("sending update", channel=group, outpost=outpost) |     LOGGER.debug("sending update", channel=group, outpost=outpost) | ||||||
|     async_to_sync(layer.group_send)(group, {"type": "event.update"}) |     async_to_sync(layer.group_send)(group, {"type": "event.update"}) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Checks the local environment and create Service connections.")) | ||||||
|     base=SystemTask, | def outpost_connection_discovery(): | ||||||
|     bind=True, |  | ||||||
| ) |  | ||||||
| def outpost_connection_discovery(self: SystemTask): |  | ||||||
|     """Checks the local environment and create Service connections.""" |     """Checks the local environment and create Service connections.""" | ||||||
|     messages = [] |     self: Task = CurrentTask.get_task() | ||||||
|     if not CONFIG.get_bool("outposts.discover"): |     if not CONFIG.get_bool("outposts.discover"): | ||||||
|         messages.append("Outpost integration discovery is disabled") |         self.info("Outpost integration discovery is disabled") | ||||||
|         self.set_status(TaskStatus.SUCCESSFUL, *messages) |  | ||||||
|         return |         return | ||||||
|     # Explicitly check against token filename, as that's |     # Explicitly check against token filename, as that's | ||||||
|     # only present when the integration is enabled |     # only present when the integration is enabled | ||||||
|     if Path(SERVICE_TOKEN_FILENAME).exists(): |     if Path(SERVICE_TOKEN_FILENAME).exists(): | ||||||
|         messages.append("Detected in-cluster Kubernetes Config") |         self.info("Detected in-cluster Kubernetes Config") | ||||||
|         if not KubernetesServiceConnection.objects.filter(local=True).exists(): |         if not KubernetesServiceConnection.objects.filter(local=True).exists(): | ||||||
|             messages.append("Created Service Connection for in-cluster") |             self.info("Created Service Connection for in-cluster") | ||||||
|             KubernetesServiceConnection.objects.create( |             KubernetesServiceConnection.objects.create( | ||||||
|                 name="Local Kubernetes Cluster", local=True, kubeconfig={} |                 name="Local Kubernetes Cluster", local=True, kubeconfig={} | ||||||
|             ) |             ) | ||||||
|     # For development, check for the existence of a kubeconfig file |     # For development, check for the existence of a kubeconfig file | ||||||
|     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() |     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() | ||||||
|     if kubeconfig_path.exists(): |     if kubeconfig_path.exists(): | ||||||
|         messages.append("Detected kubeconfig") |         self.info("Detected kubeconfig") | ||||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" |         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||||
|         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): |         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||||
|             messages.append("Creating kubeconfig Service Connection") |             self.info("Creating kubeconfig Service Connection") | ||||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: |             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||||
|                 KubernetesServiceConnection.objects.create( |                 KubernetesServiceConnection.objects.create( | ||||||
|                     name=kubeconfig_local_name, |                     name=kubeconfig_local_name, | ||||||
| @ -286,20 +197,18 @@ def outpost_connection_discovery(self: SystemTask): | |||||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path |     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||||
|     socket = Path(unix_socket_path) |     socket = Path(unix_socket_path) | ||||||
|     if socket.exists() and access(socket, R_OK): |     if socket.exists() and access(socket, R_OK): | ||||||
|         messages.append("Detected local docker socket") |         self.info("Detected local docker socket") | ||||||
|         if len(DockerServiceConnection.objects.filter(local=True)) == 0: |         if len(DockerServiceConnection.objects.filter(local=True)) == 0: | ||||||
|             messages.append("Created Service Connection for docker") |             self.info("Created Service Connection for docker") | ||||||
|             DockerServiceConnection.objects.create( |             DockerServiceConnection.objects.create( | ||||||
|                 name="Local Docker connection", |                 name="Local Docker connection", | ||||||
|                 local=True, |                 local=True, | ||||||
|                 url=unix_socket_path, |                 url=unix_socket_path, | ||||||
|             ) |             ) | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Terminate session on all outposts.")) | ||||||
| def outpost_session_end(session_id: str): | def outpost_session_end(session_id: str): | ||||||
|     """Update outpost instances connected to a single outpost""" |  | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|     hashed_session_id = hash_session_key(session_id) |     hashed_session_id = hash_session_key(session_id) | ||||||
|     for outpost in Outpost.objects.all(): |     for outpost in Outpost.objects.all(): | ||||||
|  | |||||||
| @ -37,6 +37,7 @@ class OutpostTests(TestCase): | |||||||
|  |  | ||||||
|         # We add a provider, user should only have access to outpost and provider |         # We add a provider, user should only have access to outpost and provider | ||||||
|         outpost.providers.add(provider) |         outpost.providers.add(provider) | ||||||
|  |         provider.refresh_from_db() | ||||||
|         permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( |         permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( | ||||||
|             "content_type__model" |             "content_type__model" | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -15,6 +15,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | |||||||
|     def proxy_set_defaults(self): |     def proxy_set_defaults(self): | ||||||
|         from authentik.providers.proxy.models import ProxyProvider |         from authentik.providers.proxy.models import ProxyProvider | ||||||
|  |  | ||||||
|  |         # TODO: figure out if this can be in pre_save + post_save signals | ||||||
|         for provider in ProxyProvider.objects.all(): |         for provider in ProxyProvider.objects.all(): | ||||||
|             provider.set_oauth_defaults() |             provider.set_oauth_defaults() | ||||||
|             provider.save() |             provider.save() | ||||||
|  | |||||||
							
								
								
									
										13
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								authentik/providers/proxy/signals.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,13 @@ | |||||||
|  | """Proxy provider signals""" | ||||||
|  |  | ||||||
|  | from django.db.models.signals import pre_delete | ||||||
|  | from django.dispatch import receiver | ||||||
|  |  | ||||||
|  | from authentik.core.models import AuthenticatedSession | ||||||
|  | from authentik.providers.proxy.tasks import proxy_on_logout | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
|  | def logout_proxy_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|  |     """Catch logout by expiring sessions being deleted""" | ||||||
|  |     proxy_on_logout.send(instance.session.session_key) | ||||||
							
								
								
									
										26
									
								
								authentik/providers/proxy/tasks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								authentik/providers/proxy/tasks.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | |||||||
|  | """proxy provider tasks""" | ||||||
|  |  | ||||||
|  | from asgiref.sync import async_to_sync | ||||||
|  | from channels.layers import get_channel_layer | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import actor | ||||||
|  |  | ||||||
|  | from authentik.outposts.consumer import OUTPOST_GROUP | ||||||
|  | from authentik.outposts.models import Outpost, OutpostType | ||||||
|  | from authentik.providers.oauth2.id_token import hash_session_key | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor(description=_("Terminate session on Proxy outpost.")) | ||||||
|  | def proxy_on_logout(session_id: str): | ||||||
|  |     layer = get_channel_layer() | ||||||
|  |     hashed_session_id = hash_session_key(session_id) | ||||||
|  |     for outpost in Outpost.objects.filter(type=OutpostType.PROXY): | ||||||
|  |         group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||||
|  |         async_to_sync(layer.group_send)( | ||||||
|  |             group, | ||||||
|  |             { | ||||||
|  |                 "type": "event.provider.specific", | ||||||
|  |                 "sub_type": "logout", | ||||||
|  |                 "session_id": hashed_session_id, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
| @ -17,6 +17,7 @@ from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.lib.utils.time import timedelta_string_validator | from authentik.lib.utils.time import timedelta_string_validator | ||||||
|  | from authentik.outposts.models import OutpostModel | ||||||
| from authentik.policies.models import PolicyBindingModel | from authentik.policies.models import PolicyBindingModel | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -37,7 +38,7 @@ class AuthenticationMode(models.TextChoices): | |||||||
|     PROMPT = "prompt" |     PROMPT = "prompt" | ||||||
|  |  | ||||||
|  |  | ||||||
| class RACProvider(Provider): | class RACProvider(OutpostModel, Provider): | ||||||
|     """Remotely access computers/servers via RDP/SSH/VNC.""" |     """Remotely access computers/servers via RDP/SSH/VNC.""" | ||||||
|  |  | ||||||
|     settings = models.JSONField(default=dict) |     settings = models.JSONField(default=dict) | ||||||
|  | |||||||
| @ -44,5 +44,5 @@ class SCIMProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin, ModelVie | |||||||
|     filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"] |     filterset_fields = ["name", "exclude_users_service_account", "url", "filter_group"] | ||||||
|     search_fields = ["name", "url"] |     search_fields = ["name", "url"] | ||||||
|     ordering = ["name", "url"] |     ordering = ["name", "url"] | ||||||
|     sync_single_task = scim_sync |     sync_task = scim_sync | ||||||
|     sync_objects_task = scim_sync_objects |     sync_objects_task = scim_sync_objects | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync, sync_tasks |  | ||||||
| from authentik.tenants.management import TenantCommand | from authentik.tenants.management import TenantCommand | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -21,4 +20,5 @@ class Command(TenantCommand): | |||||||
|             if not provider: |             if not provider: | ||||||
|                 LOGGER.warning("Provider does not exist", name=provider_name) |                 LOGGER.warning("Provider does not exist", name=provider_name) | ||||||
|                 continue |                 continue | ||||||
|             sync_tasks.trigger_single_task(provider, scim_sync).get() |             for schedule in provider.schedules.all(): | ||||||
|  |                 schedule.send().get_result() | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ from django.db import models | |||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| from django.templatetags.static import static | from django.templatetags.static import static | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from dramatiq.actor import Actor | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
|  |  | ||||||
| from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes | from authentik.core.models import BackchannelProvider, Group, PropertyMapping, User, UserTypes | ||||||
| @ -99,6 +100,12 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider): | |||||||
|     def icon_url(self) -> str | None: |     def icon_url(self) -> str | None: | ||||||
|         return static("authentik/sources/scim.png") |         return static("authentik/sources/scim.png") | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def sync_actor(self) -> Actor: | ||||||
|  |         from authentik.providers.scim.tasks import scim_sync | ||||||
|  |  | ||||||
|  |         return scim_sync | ||||||
|  |  | ||||||
|     def client_for_model( |     def client_for_model( | ||||||
|         self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] |         self, model: type[User | Group | SCIMProviderUser | SCIMProviderGroup] | ||||||
|     ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: |     ) -> BaseOutgoingSyncClient[User | Group, Any, Any, Self]: | ||||||
|  | |||||||
| @ -1,13 +0,0 @@ | |||||||
| """SCIM task Settings""" |  | ||||||
|  |  | ||||||
| from celery.schedules import crontab |  | ||||||
|  |  | ||||||
| from authentik.lib.utils.time import fqdn_rand |  | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { |  | ||||||
|     "providers_scim_sync": { |  | ||||||
|         "task": "authentik.providers.scim.tasks.scim_sync_all", |  | ||||||
|         "schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*/4"), |  | ||||||
|         "options": {"queue": "authentik_scheduled"}, |  | ||||||
|     }, |  | ||||||
| } |  | ||||||
| @ -2,11 +2,10 @@ | |||||||
|  |  | ||||||
| from authentik.lib.sync.outgoing.signals import register_signals | from authentik.lib.sync.outgoing.signals import register_signals | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync, scim_sync_direct, scim_sync_m2m | from authentik.providers.scim.tasks import scim_sync_direct_dispatch, scim_sync_m2m_dispatch | ||||||
|  |  | ||||||
| register_signals( | register_signals( | ||||||
|     SCIMProvider, |     SCIMProvider, | ||||||
|     task_sync_single=scim_sync, |     task_sync_direct_dispatch=scim_sync_direct_dispatch, | ||||||
|     task_sync_direct=scim_sync_direct, |     task_sync_m2m_dispatch=scim_sync_m2m_dispatch, | ||||||
|     task_sync_m2m=scim_sync_m2m, |  | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -1,37 +1,40 @@ | |||||||
| """SCIM Provider tasks""" | """SCIM Provider tasks""" | ||||||
|  |  | ||||||
| from authentik.events.system_tasks import SystemTask | from django.utils.translation import gettext_lazy as _ | ||||||
| from authentik.lib.sync.outgoing.exceptions import TransientSyncException | from dramatiq.actor import actor | ||||||
|  |  | ||||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.root.celery import CELERY_APP |  | ||||||
|  |  | ||||||
| sync_tasks = SyncTasks(SCIMProvider) | sync_tasks = SyncTasks(SCIMProvider) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor(description=_("Sync SCIM provider objects.")) | ||||||
| def scim_sync_objects(*args, **kwargs): | def scim_sync_objects(*args, **kwargs): | ||||||
|     return sync_tasks.sync_objects(*args, **kwargs) |     return sync_tasks.sync_objects(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @actor(description=_("Full sync for SCIM provider.")) | ||||||
|     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | def scim_sync(provider_pk: int, *args, **kwargs): | ||||||
| ) |  | ||||||
| def scim_sync(self, provider_pk: int, *args, **kwargs): |  | ||||||
|     """Run full sync for SCIM provider""" |     """Run full sync for SCIM provider""" | ||||||
|     return sync_tasks.sync_single(self, provider_pk, scim_sync_objects) |     return sync_tasks.sync(provider_pk, scim_sync_objects) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @actor(description=_("Sync a direct object (user, group) for SCIM provider.")) | ||||||
| def scim_sync_all(): |  | ||||||
|     return sync_tasks.sync_all(scim_sync) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) |  | ||||||
| def scim_sync_direct(*args, **kwargs): | def scim_sync_direct(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) |     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | @actor(description=_("Dispatch syncs for a direct object (user, group) for SCIM providers.")) | ||||||
|  | def scim_sync_direct_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_direct_dispatch(scim_sync_direct, *args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor(description=_("Sync a related object (memberships) for SCIM provider.")) | ||||||
| def scim_sync_m2m(*args, **kwargs): | def scim_sync_m2m(*args, **kwargs): | ||||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) |     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @actor(description=_("Dispatch syncs for a related object (memberships) for SCIM providers.")) | ||||||
|  | def scim_sync_m2m_dispatch(*args, **kwargs): | ||||||
|  |     return sync_tasks.sync_signal_m2m_dispatch(scim_sync_m2m, *args, **kwargs) | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from authentik.core.models import Application | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.clients.base import SCIMClient | from authentik.providers.scim.clients.base import SCIMClient | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync_all | from authentik.providers.scim.tasks import scim_sync | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMClientTests(TestCase): | class SCIMClientTests(TestCase): | ||||||
| @ -85,6 +85,6 @@ class SCIMClientTests(TestCase): | |||||||
|             self.assertEqual(mock.call_count, 1) |             self.assertEqual(mock.call_count, 1) | ||||||
|             self.assertEqual(mock.request_history[0].method, "GET") |             self.assertEqual(mock.request_history[0].method, "GET") | ||||||
|  |  | ||||||
|     def test_scim_sync_all(self): |     def test_scim_sync(self): | ||||||
|         """test scim_sync_all task""" |         """test scim_sync task""" | ||||||
|         scim_sync_all() |         scim_sync.send(self.provider.pk).get_result() | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync, sync_tasks | from authentik.providers.scim.tasks import scim_sync | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -79,17 +79,15 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() |             scim_sync.send(self.provider.pk) | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 6) |             self.assertEqual(mocker.call_count, 4) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "GET") |             self.assertEqual(mocker.request_history[1].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[4].method, "GET") |  | ||||||
|             self.assertEqual(mocker.request_history[5].method, "POST") |  | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[1].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "emails": [], |                     "emails": [], | ||||||
| @ -101,7 +99,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[5].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
| @ -169,17 +167,15 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() |             scim_sync.send(self.provider.pk) | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 6) |             self.assertEqual(mocker.call_count, 4) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "GET") |             self.assertEqual(mocker.request_history[1].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[4].method, "GET") |  | ||||||
|             self.assertEqual(mocker.request_history[5].method, "POST") |  | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[1].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "active": True, |                     "active": True, | ||||||
| @ -191,7 +187,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[5].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
| @ -287,17 +283,15 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() |             scim_sync.send(self.provider.pk) | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 6) |             self.assertEqual(mocker.call_count, 4) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[1].method, "GET") |             self.assertEqual(mocker.request_history[1].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[2].method, "GET") |             self.assertEqual(mocker.request_history[2].method, "GET") | ||||||
|             self.assertEqual(mocker.request_history[3].method, "POST") |             self.assertEqual(mocker.request_history[3].method, "POST") | ||||||
|             self.assertEqual(mocker.request_history[4].method, "GET") |  | ||||||
|             self.assertEqual(mocker.request_history[5].method, "POST") |  | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[3].body, |                 mocker.request_history[1].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|                     "emails": [], |                     "emails": [], | ||||||
| @ -309,7 +303,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             self.assertJSONEqual( |             self.assertJSONEqual( | ||||||
|                 mocker.request_history[5].body, |                 mocker.request_history[3].body, | ||||||
|                 { |                 { | ||||||
|                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], |                     "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"], | ||||||
|                     "externalId": str(group.pk), |                     "externalId": str(group.pk), | ||||||
|  | |||||||
| @ -9,11 +9,11 @@ from requests_mock import Mocker | |||||||
|  |  | ||||||
| from authentik.blueprints.tests import apply_blueprint | from authentik.blueprints.tests import apply_blueprint | ||||||
| from authentik.core.models import Application, Group, User | from authentik.core.models import Application, Group, User | ||||||
| from authentik.events.models import SystemTask |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.sync.outgoing.base import SAFE_METHODS | from authentik.lib.sync.outgoing.base import SAFE_METHODS | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync, sync_tasks | from authentik.providers.scim.tasks import scim_sync, scim_sync_objects | ||||||
|  | from authentik.tasks.models import Task | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -356,7 +356,7 @@ class SCIMUserTests(TestCase): | |||||||
|             email=f"{uid}@goauthentik.io", |             email=f"{uid}@goauthentik.io", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         sync_tasks.trigger_single_task(self.provider, scim_sync).get() |         scim_sync.send(self.provider.pk) | ||||||
|  |  | ||||||
|         self.assertEqual(mock.call_count, 5) |         self.assertEqual(mock.call_count, 5) | ||||||
|         self.assertEqual(mock.request_history[0].method, "GET") |         self.assertEqual(mock.request_history[0].method, "GET") | ||||||
| @ -428,14 +428,19 @@ class SCIMUserTests(TestCase): | |||||||
|                 email=f"{uid}@goauthentik.io", |                 email=f"{uid}@goauthentik.io", | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             sync_tasks.trigger_single_task(self.provider, scim_sync).get() |             scim_sync.send(self.provider.pk) | ||||||
|  |  | ||||||
|             self.assertEqual(mock.call_count, 3) |             self.assertEqual(mock.call_count, 3) | ||||||
|             for request in mock.request_history: |             for request in mock.request_history: | ||||||
|                 self.assertIn(request.method, SAFE_METHODS) |                 self.assertIn(request.method, SAFE_METHODS) | ||||||
|         task = SystemTask.objects.filter(uid=slugify(self.provider.name)).first() |         task = list( | ||||||
|  |             Task.objects.filter( | ||||||
|  |                 actor_name=scim_sync_objects.actor_name, | ||||||
|  |                 _uid=slugify(self.provider.name), | ||||||
|  |             ).order_by("-mtime") | ||||||
|  |         )[1] | ||||||
|         self.assertIsNotNone(task) |         self.assertIsNotNone(task) | ||||||
|         drop_msg = task.messages[3] |         drop_msg = task._messages[3] | ||||||
|         self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") |         self.assertEqual(drop_msg["event"], "Dropping mutating request due to dry run") | ||||||
|         self.assertIsNotNone(drop_msg["attributes"]["url"]) |         self.assertIsNotNone(drop_msg["attributes"]["url"]) | ||||||
|         self.assertIsNotNone(drop_msg["attributes"]["body"]) |         self.assertIsNotNone(drop_msg["attributes"]["body"]) | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """authentik core celery""" | """authentik core celery""" | ||||||
| 
 | 
 | ||||||
| import os | import os | ||||||
| from collections.abc import Callable |  | ||||||
| from contextvars import ContextVar | from contextvars import ContextVar | ||||||
| from logging.config import dictConfig | from logging.config import dictConfig | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @ -16,12 +15,9 @@ from celery.signals import ( | |||||||
|     task_internal_error, |     task_internal_error, | ||||||
|     task_postrun, |     task_postrun, | ||||||
|     task_prerun, |     task_prerun, | ||||||
|     worker_ready, |  | ||||||
| ) | ) | ||||||
| from celery.worker.control import inspect_command | from celery.worker.control import inspect_command | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.db import ProgrammingError |  | ||||||
| from django_tenants.utils import get_public_schema_name |  | ||||||
| from structlog.contextvars import STRUCTLOG_KEY_PREFIX | from structlog.contextvars import STRUCTLOG_KEY_PREFIX | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||||
| @ -69,7 +65,10 @@ def task_postrun_hook(task_id: str, task, *args, retval=None, state=None, **kwar | |||||||
|     """Log task_id on worker""" |     """Log task_id on worker""" | ||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     LOGGER.info( |     LOGGER.info( | ||||||
|         "Task finished", task_id=task_id.replace("-", ""), task_name=task.__name__, state=state |         "Task finished", | ||||||
|  |         task_id=task_id.replace("-", ""), | ||||||
|  |         task_name=task.__name__, | ||||||
|  |         state=state, | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -83,51 +82,12 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | |||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     if not should_ignore_exception(exception): |     if not should_ignore_exception(exception): | ||||||
|         Event.new( |         Event.new( | ||||||
|             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id |             EventAction.SYSTEM_EXCEPTION, | ||||||
|  |             message=exception_to_string(exception), | ||||||
|  |             task_id=task_id, | ||||||
|         ).save() |         ).save() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _get_startup_tasks_default_tenant() -> list[Callable]: |  | ||||||
|     """Get all tasks to be run on startup for the default tenant""" |  | ||||||
|     from authentik.outposts.tasks import outpost_connection_discovery |  | ||||||
| 
 |  | ||||||
|     return [ |  | ||||||
|         outpost_connection_discovery, |  | ||||||
|     ] |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def _get_startup_tasks_all_tenants() -> list[Callable]: |  | ||||||
|     """Get all tasks to be run on startup for all tenants""" |  | ||||||
|     return [] |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @worker_ready.connect |  | ||||||
| def worker_ready_hook(*args, **kwargs): |  | ||||||
|     """Run certain tasks on worker start""" |  | ||||||
|     from authentik.tenants.models import Tenant |  | ||||||
| 
 |  | ||||||
|     LOGGER.info("Dispatching startup tasks...") |  | ||||||
| 
 |  | ||||||
|     def _run_task(task: Callable): |  | ||||||
|         try: |  | ||||||
|             task.delay() |  | ||||||
|         except ProgrammingError as exc: |  | ||||||
|             LOGGER.warning("Startup task failed", task=task, exc=exc) |  | ||||||
| 
 |  | ||||||
|     for task in _get_startup_tasks_default_tenant(): |  | ||||||
|         with Tenant.objects.get(schema_name=get_public_schema_name()): |  | ||||||
|             _run_task(task) |  | ||||||
| 
 |  | ||||||
|     for task in _get_startup_tasks_all_tenants(): |  | ||||||
|         for tenant in Tenant.objects.filter(ready=True): |  | ||||||
|             with tenant: |  | ||||||
|                 _run_task(task) |  | ||||||
| 
 |  | ||||||
|     from authentik.blueprints.v1.tasks import start_blueprint_watcher |  | ||||||
| 
 |  | ||||||
|     start_blueprint_watcher() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class LivenessProbe(bootsteps.StartStopStep): | class LivenessProbe(bootsteps.StartStopStep): | ||||||
|     """Add a timed task to touch a temporary file for healthchecking reasons""" |     """Add a timed task to touch a temporary file for healthchecking reasons""" | ||||||
| 
 | 
 | ||||||
| @ -4,9 +4,9 @@ import importlib | |||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from tempfile import gettempdir | ||||||
|  |  | ||||||
| import orjson | import orjson | ||||||
| from celery.schedules import crontab |  | ||||||
| from sentry_sdk import set_tag | from sentry_sdk import set_tag | ||||||
| from xmlsec import enable_debug_trace | from xmlsec import enable_debug_trace | ||||||
|  |  | ||||||
| @ -65,14 +65,18 @@ SHARED_APPS = [ | |||||||
|     "pgactivity", |     "pgactivity", | ||||||
|     "pglock", |     "pglock", | ||||||
|     "channels", |     "channels", | ||||||
|  |     "django_dramatiq_postgres", | ||||||
|  |     "authentik.tasks", | ||||||
| ] | ] | ||||||
| TENANT_APPS = [ | TENANT_APPS = [ | ||||||
|     "django.contrib.auth", |     "django.contrib.auth", | ||||||
|     "django.contrib.contenttypes", |     "django.contrib.contenttypes", | ||||||
|     "django.contrib.sessions", |     "django.contrib.sessions", | ||||||
|  |     "pgtrigger", | ||||||
|     "authentik.admin", |     "authentik.admin", | ||||||
|     "authentik.api", |     "authentik.api", | ||||||
|     "authentik.crypto", |     "authentik.crypto", | ||||||
|  |     "authentik.events", | ||||||
|     "authentik.flows", |     "authentik.flows", | ||||||
|     "authentik.outposts", |     "authentik.outposts", | ||||||
|     "authentik.policies.dummy", |     "authentik.policies.dummy", | ||||||
| @ -120,6 +124,7 @@ TENANT_APPS = [ | |||||||
|     "authentik.stages.user_login", |     "authentik.stages.user_login", | ||||||
|     "authentik.stages.user_logout", |     "authentik.stages.user_logout", | ||||||
|     "authentik.stages.user_write", |     "authentik.stages.user_write", | ||||||
|  |     "authentik.tasks.schedules", | ||||||
|     "authentik.brands", |     "authentik.brands", | ||||||
|     "authentik.blueprints", |     "authentik.blueprints", | ||||||
|     "guardian", |     "guardian", | ||||||
| @ -166,6 +171,7 @@ SPECTACULAR_SETTINGS = { | |||||||
|         "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", |         "UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification", | ||||||
|         "UserTypeEnum": "authentik.core.models.UserTypes", |         "UserTypeEnum": "authentik.core.models.UserTypes", | ||||||
|         "OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction", |         "OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction", | ||||||
|  |         "TaskAggregatedStatusEnum": "authentik.tasks.models.TaskStatus", | ||||||
|     }, |     }, | ||||||
|     "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False, |     "ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False, | ||||||
|     "ENUM_GENERATE_CHOICE_DESCRIPTION": False, |     "ENUM_GENERATE_CHOICE_DESCRIPTION": False, | ||||||
| @ -341,37 +347,85 @@ USE_TZ = True | |||||||
|  |  | ||||||
| LOCALE_PATHS = ["./locale"] | LOCALE_PATHS = ["./locale"] | ||||||
|  |  | ||||||
| CELERY = { |  | ||||||
|     "task_soft_time_limit": 600, | # Tests | ||||||
|     "worker_max_tasks_per_child": 50, |  | ||||||
|     "worker_concurrency": CONFIG.get_int("worker.concurrency"), | TEST = False | ||||||
|     "beat_schedule": { | TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" | ||||||
|         "clean_expired_models": { |  | ||||||
|             "task": "authentik.core.tasks.clean_expired_models", |  | ||||||
|             "schedule": crontab(minute="2-59/5"), | # Dramatiq | ||||||
|             "options": {"queue": "authentik_scheduled"}, |  | ||||||
|  | DRAMATIQ = { | ||||||
|  |     "broker_class": "authentik.tasks.broker.Broker", | ||||||
|  |     "channel_prefix": "authentik", | ||||||
|  |     "task_model": "authentik.tasks.models.Task", | ||||||
|  |     "task_purge_interval": timedelta_from_string( | ||||||
|  |         CONFIG.get("worker.task_purge_interval") | ||||||
|  |     ).total_seconds(), | ||||||
|  |     "task_expiration": timedelta_from_string(CONFIG.get("worker.task_expiration")).total_seconds(), | ||||||
|  |     "autodiscovery": { | ||||||
|  |         "enabled": True, | ||||||
|  |         "setup_module": "authentik.tasks.setup", | ||||||
|  |         "apps_prefix": "authentik", | ||||||
|     }, |     }, | ||||||
|         "user_cleanup": { |     "worker": { | ||||||
|             "task": "authentik.core.tasks.clean_temporary_users", |         "processes": CONFIG.get_int("worker.processes", 2), | ||||||
|             "schedule": crontab(minute="9-59/5"), |         "threads": CONFIG.get_int("worker.threads", 1), | ||||||
|             "options": {"queue": "authentik_scheduled"}, |         "consumer_listen_timeout": timedelta_from_string( | ||||||
|  |             CONFIG.get("worker.consumer_listen_timeout") | ||||||
|  |         ).total_seconds(), | ||||||
|  |         "watch_folder": BASE_DIR / "authentik", | ||||||
|     }, |     }, | ||||||
|  |     "scheduler_class": "authentik.tasks.schedules.scheduler.Scheduler", | ||||||
|  |     "schedule_model": "authentik.tasks.schedules.models.Schedule", | ||||||
|  |     "scheduler_interval": timedelta_from_string( | ||||||
|  |         CONFIG.get("worker.scheduler_interval") | ||||||
|  |     ).total_seconds(), | ||||||
|  |     "middlewares": ( | ||||||
|  |         ("django_dramatiq_postgres.middleware.FullyQualifiedActorName", {}), | ||||||
|  |         # TODO: fixme | ||||||
|  |         # ("dramatiq.middleware.prometheus.Prometheus", {}), | ||||||
|  |         ("django_dramatiq_postgres.middleware.DbConnectionMiddleware", {}), | ||||||
|  |         ("dramatiq.middleware.age_limit.AgeLimit", {}), | ||||||
|  |         ( | ||||||
|  |             "dramatiq.middleware.time_limit.TimeLimit", | ||||||
|  |             { | ||||||
|  |                 "time_limit": timedelta_from_string( | ||||||
|  |                     CONFIG.get("worker.task_default_time_limit") | ||||||
|  |                 ).total_seconds() | ||||||
|  |                 * 1000 | ||||||
|             }, |             }, | ||||||
|     "beat_scheduler": "authentik.tenants.scheduler:TenantAwarePersistentScheduler", |  | ||||||
|     "task_create_missing_queues": True, |  | ||||||
|     "task_default_queue": "authentik", |  | ||||||
|     "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")), |  | ||||||
|     "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")), |  | ||||||
|     "broker_transport_options": CONFIG.get_dict_from_b64_json( |  | ||||||
|         "broker.transport_options", {"retry_policy": {"timeout": 5.0}} |  | ||||||
|         ), |         ), | ||||||
|     "result_backend_transport_options": CONFIG.get_dict_from_b64_json( |         ("dramatiq.middleware.shutdown.ShutdownNotifications", {}), | ||||||
|         "result_backend.transport_options", {"retry_policy": {"timeout": 5.0}} |         ("dramatiq.middleware.callbacks.Callbacks", {}), | ||||||
|  |         ("dramatiq.middleware.pipelines.Pipelines", {}), | ||||||
|  |         ( | ||||||
|  |             "dramatiq.middleware.retries.Retries", | ||||||
|  |             {"max_retries": CONFIG.get_int("worker.task_max_retries") if not TEST else 0}, | ||||||
|         ), |         ), | ||||||
|     "redis_retry_on_timeout": True, |         ("dramatiq.results.middleware.Results", {"store_results": True}), | ||||||
|  |         ("django_dramatiq_postgres.middleware.CurrentTask", {}), | ||||||
|  |         ("authentik.tasks.middleware.TenantMiddleware", {}), | ||||||
|  |         ("authentik.tasks.middleware.RelObjMiddleware", {}), | ||||||
|  |         ("authentik.tasks.middleware.MessagesMiddleware", {}), | ||||||
|  |         ("authentik.tasks.middleware.LoggingMiddleware", {}), | ||||||
|  |         ("authentik.tasks.middleware.DescriptionMiddleware", {}), | ||||||
|  |         ("authentik.tasks.middleware.WorkerStatusMiddleware", {}), | ||||||
|  |         ( | ||||||
|  |             "authentik.tasks.middleware.MetricsMiddleware", | ||||||
|  |             { | ||||||
|  |                 "multiproc_dir": str(Path(gettempdir()) / "authentik_prometheus_tmp"), | ||||||
|  |                 "prefix": "authentik", | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |     ), | ||||||
|  |     "test": TEST, | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| # Sentry integration | # Sentry integration | ||||||
|  |  | ||||||
| env = get_env() | env = get_env() | ||||||
| _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) | _ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) | ||||||
| if _ERROR_REPORTING: | if _ERROR_REPORTING: | ||||||
| @ -432,9 +486,6 @@ else: | |||||||
|     MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] |     MEDIA_ROOT = STORAGES["default"]["OPTIONS"]["location"] | ||||||
|     MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] |     MEDIA_URL = STORAGES["default"]["OPTIONS"]["base_url"] | ||||||
|  |  | ||||||
| TEST = False |  | ||||||
| TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" |  | ||||||
|  |  | ||||||
| structlog_configure() | structlog_configure() | ||||||
| LOGGING = get_logger_config() | LOGGING = get_logger_config() | ||||||
|  |  | ||||||
| @ -445,7 +496,6 @@ _DISALLOWED_ITEMS = [ | |||||||
|     "INSTALLED_APPS", |     "INSTALLED_APPS", | ||||||
|     "MIDDLEWARE", |     "MIDDLEWARE", | ||||||
|     "AUTHENTICATION_BACKENDS", |     "AUTHENTICATION_BACKENDS", | ||||||
|     "CELERY", |  | ||||||
|     "SPECTACULAR_SETTINGS", |     "SPECTACULAR_SETTINGS", | ||||||
|     "REST_FRAMEWORK", |     "REST_FRAMEWORK", | ||||||
| ] | ] | ||||||
| @ -472,7 +522,6 @@ def _update_settings(app_path: str): | |||||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) |         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) |         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) | ||||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) |         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) | ||||||
|         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) |  | ||||||
|         for _attr in dir(settings_module): |         for _attr in dir(settings_module): | ||||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: |             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||||
|                 globals()[_attr] = getattr(settings_module, _attr) |                 globals()[_attr] = getattr(settings_module, _attr) | ||||||
| @ -481,7 +530,6 @@ def _update_settings(app_path: str): | |||||||
|  |  | ||||||
|  |  | ||||||
| if DEBUG: | if DEBUG: | ||||||
|     CELERY["task_always_eager"] = True |  | ||||||
|     REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( |     REST_FRAMEWORK["DEFAULT_RENDERER_CLASSES"].append( | ||||||
|         "rest_framework.renderers.BrowsableAPIRenderer" |         "rest_framework.renderers.BrowsableAPIRenderer" | ||||||
|     ) |     ) | ||||||
| @ -501,10 +549,6 @@ try: | |||||||
| except ImportError: | except ImportError: | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
| # Import events after other apps since it relies on tasks and other things from all apps |  | ||||||
| # being imported for @prefill_task |  | ||||||
| TENANT_APPS.append("authentik.events") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Load subapps's settings | # Load subapps's settings | ||||||
| for _app in set(SHARED_APPS + TENANT_APPS): | for _app in set(SHARED_APPS + TENANT_APPS): | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user