Compare commits

296 Commits

website/do ... celery-2-d
| SHA1 | Author | Date | |
|---|---|---|---|
| 80eb56b016 | |||
| 3a1d0fbd35 | |||
| c85471575a | |||
| 5d00dc7e9e | |||
| 6982e7d1c9 | |||
| c7fe987c5a | |||
| e48739c8a0 | |||
| b2ee585c43 | |||
| 97e8ea8e76 | |||
| 1f1e0c9db1 | |||
| ca47a803fe | |||
| c606eb53b0 | |||
| c2dc38e804 | |||
| 9aad7e29a4 | |||
| 62357133b0 | |||
| 99d2d91257 | |||
| 69d9363fce | |||
| 315f9073cc | |||
| c94fa13826 | |||
| 6b3fbb0abf | |||
| 06191d6cfc | |||
| c608dc5110 | |||
| bb05d6063d | |||
| c5b4a630c9 | |||
| 4a73c65710 | |||
| 51e645c26a | |||
| de3607b9a6 | |||
| fb5804b914 | |||
| 287f41fed8 | |||
| 172b595a9f | |||
| d61995d0e2 | |||
| cfc7f6b993 | |||
| 056460dac0 | |||
| bebbbe9b90 | |||
| 188d3c69c1 | |||
| 877f312145 | |||
| f471a98bc7 | |||
| b30b736a71 | |||
| d3aa43ced0 | |||
| 92c837f7b5 | |||
| 7fcd65e318 | |||
| e874cfc21d | |||
| ec7bdf74aa | |||
| 9c01c7d890 | |||
| bdb0564d4c | |||
| e87bc94b95 | |||
| 09688f7a55 | |||
| 31bb995490 | |||
| 5d813438f0 | |||
| a3865abaa9 | |||
| 7100d3c674 | |||
| d92764789f | |||
| c702fa1f95 | |||
| 786bada7d0 | |||
| c0c2d2ad3c | |||
| dc287989db | |||
| 03204f6943 | |||
| fcd369e466 | |||
| cb79407bc1 | |||
| 04a88daf34 | |||
| c6a49da5c3 | |||
| bfeeecf3fa | |||
| d86b5e7c8a | |||
| 690766d377 | |||
| 4990abdf4a | |||
| 76616cf1c5 | |||
| 48c0b5449e | |||
| 8956606564 | |||
| a95776891e | |||
| 0908d0b559 | |||
| 7e5c90fcdc | |||
| 4e210b8299 | |||
| bf65ca4c70 | |||
| 670a88659e | |||
| 0ebbaeea6f | |||
| 49a9911271 | |||
| f3c0ca1a59 | |||
| 64e7bff16c | |||
| f264989f9e | |||
| 031158fdba | |||
| b2fbb92498 | |||
| 8f4353181e | |||
| b1b6bf1a19 | |||
| 179d9d0721 | |||
| 8e94d58851 | |||
| 026669cfce | |||
| c83cea6963 | |||
| 8e01cc2df8 | |||
| 279cec203d | |||
| 41c5030c1e | |||
| 3206fdb7ef | |||
| d7c0868eef | |||
| 7d96a89697 | |||
| 813b7aa8ba | |||
| 3bb3e0d1ef | |||
| dfb0007777 | |||
| 101e5adeba | |||
| ae228c91e3 | |||
| 411c52491e | |||
| 85869806a2 | |||
| d132db475e | |||
| 13b5aa604b | |||
| 97a5acdff5 | |||
| ea38f2d120 | |||
| 94867aaebf | |||
| 0e67c1d818 | |||
| 2a460201bb | |||
| f99cb3e9fb | |||
| e4bd05f444 | |||
| 80c4eb9bef | |||
| 96b4d5aee4 | |||
| 6321537c8d | |||
| 43975ec231 | |||
| 9b13922fc2 | |||
| 031456629b | |||
| 2433ed1c9b | |||
| 9868d54320 | |||
| 747a3ed6e9 | |||
| 527e849ce2 | |||
| cfcd54ca19 | |||
| faed9cd66e | |||
| 897d0dbcbd | |||
| a12e991798 | |||
| e5b86c3578 | |||
| 07ff433134 | |||
| 21b3e0c8cb | |||
| cbdec236dd | |||
| 2509ccde1c | |||
| 7e7b33dba7 | |||
| 13e1e44626 | |||
| e634f23fc8 | |||
| 8554a8e0c5 | |||
| b80abffafc | |||
| 204f21699e | |||
| 0fd478fa3e | |||
| 7d7e47e972 | |||
| 92a33a408f | |||
| d18a54e9e6 | |||
| e6614a0705 | |||
| 4c491cf221 | |||
| 17434c84cf | |||
| 234fb2a0c6 | |||
| 00612f921d | |||
| 8b67015190 | |||
| 5a5176e21f | |||
| 8980282a02 | |||
| 2ca9edb1bc | |||
| 61d970cda4 | |||
| 16fd9cab67 | |||
| 8c7818a252 | |||
| 374779102a | |||
| 0ac854458a | |||
| 1cfaddf49d | |||
| 5ae69f5987 | |||
| c62e1d5457 | |||
| 5b8681b1af | |||
| e0dcade9ad | |||
| 1a6ab7f24b | |||
| 769844314c | |||
| e211604860 | |||
| 7ed711e8f0 | |||
| 196b276345 | |||
| 3c62c80ff1 | |||
| a031f1107a | |||
| 8f399bba3f | |||
| e354e877ea | |||
| f254b8cf8c | |||
| 814b06322a | |||
| 217063ef7b | |||
| c2f7883a5c | |||
| bd64c34787 | |||
| 7518d4391f | |||
| e67bd79c66 | |||
| 2fc6da53c1 | |||
| 250a98cf98 | |||
| f2926fa1eb | |||
| 5e2af4a740 | |||
| 41f2ca42cc | |||
| 7ef547b357 | |||
| 1a9c529e92 | |||
| 75d19bfe76 | |||
| 7f8f7376e0 | |||
| 7c49de9cba | |||
| 00ac9b6367 | |||
| 0e786f7040 | |||
| 03d363ba84 | |||
| 3f33519ec0 | |||
| cae03beb6d | |||
| e4c1e5aed0 | |||
| 5acdd67cba | |||
| 40dbac7a65 | |||
| 1b4ed02959 | |||
| a95e730cdb | |||
| d8c13159e1 | |||
| 5f951ca3ef | |||
| 338da72622 | |||
| 90debcdd70 | |||
| 3766ca86e8 | |||
| 59c8472628 | |||
| 293616e6b0 | |||
| f7305d58b1 | |||
| ba94f63705 | |||
| 06b2e0d14b | |||
| 80a5f44491 | |||
| aca0bde46d | |||
| e671811ad2 | |||
| 3140325493 | |||
| 6c0b879b30 | |||
| 0e0fb37dd7 | |||
| d2cacdc640 | |||
| e65fabf040 | |||
| 107b96e65c | |||
| 5d7ba51872 | |||
| 3037701a14 | |||
| 66f8377c79 | |||
| 86f81d92aa | |||
| 369437f2a1 | |||
| 4b8b80f1d4 | |||
| f839aef33a | |||
| eb87e30076 | |||
| 4302f91028 | |||
| b0af20b0d5 | |||
| 9b556cf4c4 | |||
| 7118219544 | |||
| 475600ea87 | |||
| 2139e0be05 | |||
| a43a0f77fb | |||
| 8a073e8c60 | |||
| 35640fcdfa | |||
| c62f73400a | |||
| c92cbd7e22 | |||
| c3b0d09e04 | |||
| c5a40fced3 | |||
| 9cc6ebabc1 | |||
| e89659fe71 | |||
| 3c1512028d | |||
| c7f80686de | |||
| 8a8386cfcb | |||
| e60165ee45 | |||
| bc6085adc7 | |||
| d413e2875c | |||
| 144986f48e | |||
| c9f1e34beb | |||
| 39f769b150 | |||
| 78180e376f | |||
| d51150102c | |||
| d5da16ad26 | |||
| 2b12e32fcf | |||
| a9b9661155 | |||
| f5f0cef275 | |||
| b756965511 | |||
| db900c4a42 | |||
| 72cb62085b | |||
| 5a42815850 | |||
| b9083a906a | |||
| 04be734c49 | |||
| 1ed6cf7517 | |||
| d6c4f97158 | |||
| 781704fa38 | |||
| 28f4d7d566 | |||
| 991778b2be | |||
| 9465dafd7d | |||
| 75c13a8801 | |||
| 8ae0f145f5 | |||
| 4d0e0e3afe | |||
| 7aeb874ded | |||
| ffc695f7b8 | |||
| 93cb621af3 | |||
| 3a34680196 | |||
| 2335a3130a | |||
| 0bc4b69f52 | |||
| 43c5c1276d | |||
| a3ebfd9bbd | |||
| af5b894e62 | |||
| c982066235 | |||
| 1f6c1522b6 | |||
| bae83ba78e | |||
| 0d0aeab4ee | |||
| 7fe91339ad | |||
| 44dea1d208 | |||
| 6d3be40022 | |||
| 07773f92a0 | |||
| dbc4a2b730 | |||
| df15a78aac | |||
| 61b517edfa | |||
| 082b342f65 | |||
| 9a536ee4b9 | |||
| 677f04cab2 | |||
| 3ddc35cddc | |||
| ae211226ef | |||
| 6662611347 | |||
| c4b988c632 | |||
| 2b1ee8cd5c | |||
| e8cfc2b91e | |||
| de54404ab7 | |||
| f8c3b64274 | |||
```diff
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 2025.6.2
+current_version = 2025.6.3
 tag = True
 commit = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\d*))?
```
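The `parse` regex above is how bumpversion decomposes a version string (the doubled backslash in the scraped text was an escaping artifact; the pattern uses `\d`). As a quick sanity check (a minimal sketch, not part of the change), the same pattern can be exercised with Python's `re` module:

```python
import re

# The parse pattern from the bumpversion config above.
PARSE = re.compile(
    r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
    r"(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\d*))?"
)

m = PARSE.match("2025.6.3")
assert m and m.group("major", "minor", "patch") == ("2025", "6", "3")

# Release-candidate suffixes land in the optional named groups.
m = PARSE.match("2025.6.3-rc1")
assert m and m.group("rc_t", "rc_n") == ("rc", "1")
```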
```diff
@@ -38,6 +38,8 @@ jobs:
       # Needed for attestation
       id-token: write
       attestations: write
+      # Needed for checkout
+      contents: read
     steps:
       - uses: actions/checkout@v4
       - uses: docker/setup-qemu-action@v3.6.0
```
.github/workflows/ci-main-daily.yml (vendored; 1 line changed)
```diff
@@ -9,6 +9,7 @@ on:
 
 jobs:
   test-container:
+    if: ${{ github.repository != 'goauthentik/authentik-internal' }}
     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
```
.github/workflows/ci-main.yml (vendored; 4 lines changed)
```diff
@@ -247,11 +247,13 @@ jobs:
       # Needed for attestation
       id-token: write
       attestations: write
+      # Needed for checkout
+      contents: read
     needs: ci-core-mark
     uses: ./.github/workflows/_reusable-docker-build.yaml
     secrets: inherit
     with:
-      image_name: ghcr.io/goauthentik/dev-server
+      image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }}
       release: false
   pr-comment:
     needs:
```
.github/workflows/ci-outpost.yml (vendored; 1 line changed)
```diff
@@ -59,6 +59,7 @@ jobs:
         with:
           jobs: ${{ toJSON(needs) }}
   build-container:
+    if: ${{ github.repository != 'goauthentik/authentik-internal' }}
     timeout-minutes: 120
     needs:
       - ci-outpost-mark
```
.github/workflows/ci-website.yml (vendored; 2 lines changed)
```diff
@@ -63,6 +63,7 @@ jobs:
         working-directory: website/
         run: npm run ${{ matrix.job }}
   build-container:
+    if: ${{ github.repository != 'goauthentik/authentik-internal' }}
     runs-on: ubuntu-latest
     permissions:
       # Needed to upload container images to ghcr.io
@@ -122,3 +123,4 @@ jobs:
       - uses: re-actors/alls-green@release/v1
         with:
           jobs: ${{ toJSON(needs) }}
+          allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }}
```
.github/workflows/repo-mirror-cleanup.yml (vendored; new file; 21 lines)
```diff
@@ -0,0 +1,21 @@
+name: "authentik-repo-mirror-cleanup"
+
+on:
+  workflow_dispatch:
+
+jobs:
+  to_internal:
+    if: ${{ github.repository != 'goauthentik/authentik-internal' }}
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - if: ${{ env.MIRROR_KEY != '' }}
+        uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
+        with:
+          target_repo_url: git@github.com:goauthentik/authentik-internal.git
+          ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
+          args: --tags --force --prune
+        env:
+          MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
```
.github/workflows/repo-mirror.yml (vendored; 9 lines changed)
```diff
@@ -11,11 +11,10 @@ jobs:
         with:
           fetch-depth: 0
       - if: ${{ env.MIRROR_KEY != '' }}
-        uses: pixta-dev/repository-mirroring-action@v1
+        uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
         with:
-          target_repo_url:
-            git@github.com:goauthentik/authentik-internal.git
-          ssh_private_key:
-            ${{ secrets.GH_MIRROR_KEY }}
+          target_repo_url: git@github.com:goauthentik/authentik-internal.git
+          ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
           args: --tags --force
         env:
           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}
```
```diff
@@ -16,6 +16,7 @@ env:
 
 jobs:
   compile:
+    if: ${{ github.repository != 'goauthentik/authentik-internal' }}
     runs-on: ubuntu-latest
     steps:
      - id: generate_token
```
.gitignore (vendored; 5 lines changed)
```diff
@@ -100,9 +100,6 @@ ipython_config.py
 # pyenv
 .python-version
 
-# celery beat schedule file
-celerybeat-schedule
-
 # SageMath parsed files
 *.sage.py
 
@@ -166,8 +163,6 @@ dmypy.json
 
 # pyenv
 
-# celery beat schedule file
-
 # SageMath parsed files
 
 # Environments
```
```diff
@@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
 
 # Stage 4: Download uv
-FROM ghcr.io/astral-sh/uv:0.7.14 AS uv
+FROM ghcr.io/astral-sh/uv:0.7.17 AS uv
 # Stage 5: Base python image
 FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
 
@@ -122,6 +122,7 @@ ENV UV_NO_BINARY_PACKAGE="cryptography lxml python-kadmin-rs xmlsec"
 
 RUN --mount=type=bind,target=pyproject.toml,src=pyproject.toml \
     --mount=type=bind,target=uv.lock,src=uv.lock \
+    --mount=type=bind,target=packages,src=packages \
     --mount=type=cache,target=/root/.cache/uv \
     uv sync --frozen --no-install-project --no-dev
 
@@ -167,6 +168,7 @@ COPY ./blueprints /blueprints
 COPY ./lifecycle/ /lifecycle
 COPY ./authentik/sources/kerberos/krb5.conf /etc/krb5.conf
 COPY --from=go-builder /go/authentik /bin/authentik
+COPY ./packages/ /ak-root/packages
 COPY --from=python-deps /ak-root/.venv /ak-root/.venv
 COPY --from=node-builder /work/web/dist/ /web/dist/
 COPY --from=node-builder /work/web/authentik/ /web/authentik/
```
Makefile (8 lines changed)
```diff
@@ -6,7 +6,7 @@ PWD = $(shell pwd)
 UID = $(shell id -u)
 GID = $(shell id -g)
 NPM_VERSION = $(shell python -m scripts.generate_semver)
-PY_SOURCES = authentik tests scripts lifecycle .github
+PY_SOURCES = authentik packages tests scripts lifecycle .github
 DOCKER_IMAGE ?= "authentik:test"
 
 GEN_API_TS = gen-ts-api
@@ -150,9 +150,9 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri
 		--additional-properties=npmVersion=${NPM_VERSION} \
 		--git-repo-id authentik \
 		--git-user-id goauthentik
-	mkdir -p web/node_modules/@goauthentik/api
-	cd ${PWD}/${GEN_API_TS} && npm i
-	\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api
+
+	cd ${PWD}/${GEN_API_TS} && npm link
+	cd ${PWD}/web && npm link @goauthentik/api
 
 gen-client-py: gen-clean-py ## Build and install the authentik API for Python
 	docker run \
```
```diff
@@ -2,7 +2,7 @@
 
 from os import environ
 
-__version__ = "2025.6.2"
+__version__ = "2025.6.3"
 ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
```
```diff
@@ -41,7 +41,7 @@ class VersionSerializer(PassiveSerializer):
             return __version__
         version_in_cache = cache.get(VERSION_CACHE_KEY)
         if not version_in_cache:  # pragma: no cover
-            update_latest_version.delay()
+            update_latest_version.send()
             return __version__
         return version_in_cache
```
```diff
@@ -1,57 +0,0 @@
-"""authentik administration overview"""
-
-from socket import gethostname
-
-from django.conf import settings
-from drf_spectacular.utils import extend_schema, inline_serializer
-from packaging.version import parse
-from rest_framework.fields import BooleanField, CharField
-from rest_framework.request import Request
-from rest_framework.response import Response
-from rest_framework.views import APIView
-
-from authentik import get_full_version
-from authentik.rbac.permissions import HasPermission
-from authentik.root.celery import CELERY_APP
-
-
-class WorkerView(APIView):
-    """Get currently connected worker count."""
-
-    permission_classes = [HasPermission("authentik_rbac.view_system_info")]
-
-    @extend_schema(
-        responses=inline_serializer(
-            "Worker",
-            fields={
-                "worker_id": CharField(),
-                "version": CharField(),
-                "version_matching": BooleanField(),
-            },
-            many=True,
-        )
-    )
-    def get(self, request: Request) -> Response:
-        """Get currently connected worker count."""
-        raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
-        our_version = parse(get_full_version())
-        response = []
-        for worker in raw:
-            key = list(worker.keys())[0]
-            version = worker[key].get("version")
-            version_matching = False
-            if version:
-                version_matching = parse(version) == our_version
-            response.append(
-                {"worker_id": key, "version": version, "version_matching": version_matching}
-            )
-        # In debug we run with `task_always_eager`, so tasks are ran on the main process
-        if settings.DEBUG:  # pragma: no cover
-            response.append(
-                {
-                    "worker_id": f"authentik-debug@{gethostname()}",
-                    "version": get_full_version(),
-                    "version_matching": True,
-                }
-            )
-        return Response(response)
```
```diff
@@ -3,6 +3,9 @@
 from prometheus_client import Info
 
 from authentik.blueprints.apps import ManagedAppConfig
+from authentik.lib.config import CONFIG
+from authentik.lib.utils.time import fqdn_rand
+from authentik.tasks.schedules.lib import ScheduleSpec
 
 PROM_INFO = Info("authentik_version", "Currently running authentik version")
 
@@ -30,3 +33,15 @@ class AuthentikAdminConfig(ManagedAppConfig):
             notification_version = notification.event.context["new_version"]
             if LOCAL_VERSION >= parse(notification_version):
                 notification.delete()
+
+    @property
+    def global_schedule_specs(self) -> list[ScheduleSpec]:
+        from authentik.admin.tasks import update_latest_version
+
+        return [
+            ScheduleSpec(
+                actor=update_latest_version,
+                crontab=f"{fqdn_rand('admin_latest_version')} * * * *",
+                paused=CONFIG.get_bool("disable_update_check"),
+            ),
+        ]
```
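This `ScheduleSpec` replaces the Celery beat entry deleted below; `fqdn_rand` staggers the minute at which each install fires so a fleet does not hit the update endpoint simultaneously. A rough illustration of such a deterministic per-host offset (a sketch of the idea, not authentik's actual implementation):

```python
from hashlib import sha256
from socket import getfqdn


def fqdn_rand_sketch(key: str, stop: int = 60) -> int:
    """Deterministically map this host + key to a value in [0, stop).

    The value is stable across restarts but effectively random across
    hosts, spreading scheduled jobs over the hour.
    """
    seed = f"{getfqdn()}/{key}".encode()
    return int.from_bytes(sha256(seed).digest()[:4], "big") % stop


# e.g. "37 * * * *": once per hour, at a host-specific minute
print(f"{fqdn_rand_sketch('admin_latest_version')} * * * *")
```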
```diff
@@ -1,15 +0,0 @@
-"""authentik admin settings"""
-
-from celery.schedules import crontab
-from django_tenants.utils import get_public_schema_name
-
-from authentik.lib.utils.time import fqdn_rand
-
-CELERY_BEAT_SCHEDULE = {
-    "admin_latest_version": {
-        "task": "authentik.admin.tasks.update_latest_version",
-        "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
-        "tenant_schemas": [get_public_schema_name()],
-        "options": {"queue": "authentik_scheduled"},
-    }
-}
```
```diff
@@ -1,35 +0,0 @@
-"""admin signals"""
-
-from django.dispatch import receiver
-from packaging.version import parse
-from prometheus_client import Gauge
-
-from authentik import get_full_version
-from authentik.root.celery import CELERY_APP
-from authentik.root.monitoring import monitoring_set
-
-GAUGE_WORKERS = Gauge(
-    "authentik_admin_workers",
-    "Currently connected workers, their versions and if they are the same version as authentik",
-    ["version", "version_matched"],
-)
-
-
-_version = parse(get_full_version())
-
-
-@receiver(monitoring_set)
-def monitoring_set_workers(sender, **kwargs):
-    """Set worker gauge"""
-    raw: list[dict[str, dict]] = CELERY_APP.control.ping(timeout=0.5)
-    worker_version_count = {}
-    for worker in raw:
-        key = list(worker.keys())[0]
-        version = worker[key].get("version")
-        version_matching = False
-        if version:
-            version_matching = parse(version) == _version
-        worker_version_count.setdefault(version, {"count": 0, "matching": version_matching})
-        worker_version_count[version]["count"] += 1
-    for version, stats in worker_version_count.items():
-        GAUGE_WORKERS.labels(version, stats["matching"]).set(stats["count"])
```
```diff
@@ -2,6 +2,8 @@
 
 from django.core.cache import cache
 from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask
+from dramatiq import actor
 from packaging.version import parse
 from requests import RequestException
 from structlog.stdlib import get_logger
@@ -9,10 +11,9 @@ from structlog.stdlib import get_logger
 from authentik import __version__, get_build_hash
 from authentik.admin.apps import PROM_INFO
 from authentik.events.models import Event, EventAction
-from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 from authentik.lib.config import CONFIG
 from authentik.lib.utils.http import get_http_session
-from authentik.root.celery import CELERY_APP
+from authentik.tasks.models import Task
 
 LOGGER = get_logger()
 VERSION_NULL = "0.0.0"
@@ -32,13 +33,12 @@ def _set_prom_info():
     )
 
 
-@CELERY_APP.task(bind=True, base=SystemTask)
-@prefill_task
-def update_latest_version(self: SystemTask):
-    """Update latest version info"""
+@actor(description=_("Update latest version info."))
+def update_latest_version():
+    self: Task = CurrentTask.get_task()
     if CONFIG.get_bool("disable_update_check"):
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
-        self.set_status(TaskStatus.WARNING, "Version check disabled.")
+        self.info("Version check disabled.")
         return
     try:
         response = get_http_session().get(
@@ -48,7 +48,7 @@ def update_latest_version(self: SystemTask):
         data = response.json()
         upstream_version = data.get("stable", {}).get("version")
         cache.set(VERSION_CACHE_KEY, upstream_version, VERSION_CACHE_TIMEOUT)
-        self.set_status(TaskStatus.SUCCESSFUL, "Successfully updated latest Version")
+        self.info("Successfully updated latest Version")
         _set_prom_info()
         # Check if upstream version is newer than what we're running,
         # and if no event exists yet, create one.
@@ -71,7 +71,7 @@ def update_latest_version(self: SystemTask):
             ).save()
     except (RequestException, IndexError) as exc:
         cache.set(VERSION_CACHE_KEY, VERSION_NULL, VERSION_CACHE_TIMEOUT)
-        self.set_error(exc)
+        raise exc
 
 
 _set_prom_info()
```
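Call sites change in step: Celery's `.delay()` becomes dramatiq's `.send()` everywhere. A minimal runnable sketch of the new dispatch style, using dramatiq's in-memory stub broker rather than the broker authentik actually configures:

```python
import dramatiq
from dramatiq.brokers.stub import StubBroker

dramatiq.set_broker(StubBroker())  # in-memory stand-in for the real broker


@dramatiq.actor
def update_latest_version_demo():
    # authentik's real actor also carries a description= option and pulls
    # its Task row via django_dramatiq_postgres's CurrentTask middleware;
    # both are omitted here to keep the sketch dependency-free.
    print("checking version.goauthentik.io")


update_latest_version_demo.send()  # was: update_latest_version.delay()
```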
```diff
@@ -29,13 +29,6 @@ class TestAdminAPI(TestCase):
         body = loads(response.content)
         self.assertEqual(body["version_current"], __version__)
 
-    def test_workers(self):
-        """Test Workers API"""
-        response = self.client.get(reverse("authentik_api:admin_workers"))
-        self.assertEqual(response.status_code, 200)
-        body = loads(response.content)
-        self.assertEqual(len(body), 0)
-
     def test_apps(self):
         """Test apps API"""
         response = self.client.get(reverse("authentik_api:apps-list"))
```
```diff
@@ -30,7 +30,7 @@
         """Test Update checker with valid response"""
         with Mocker() as mocker, CONFIG.patch("disable_update_check", False):
             mocker.get("https://version.goauthentik.io/version.json", json=RESPONSE_VALID)
-            update_latest_version.delay().get()
+            update_latest_version.send()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "99999999.9999999")
             self.assertTrue(
                 Event.objects.filter(
@@ -40,7 +40,7 @@
                 ).exists()
             )
             # test that a consecutive check doesn't create a duplicate event
-            update_latest_version.delay().get()
+            update_latest_version.send()
             self.assertEqual(
                 len(
                     Event.objects.filter(
@@ -56,7 +56,7 @@
         """Test Update checker with invalid response"""
         with Mocker() as mocker:
             mocker.get("https://version.goauthentik.io/version.json", status_code=400)
-            update_latest_version.delay().get()
+            update_latest_version.send()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
             self.assertFalse(
                 Event.objects.filter(
@@ -67,14 +67,15 @@
     def test_version_disabled(self):
         """Test Update checker while its disabled"""
         with CONFIG.patch("disable_update_check", True):
-            update_latest_version.delay().get()
+            update_latest_version.send()
             self.assertEqual(cache.get(VERSION_CACHE_KEY), "0.0.0")
 
     def test_clear_update_notifications(self):
         """Test clear of previous notification"""
         admin_config = apps.get_app_config("authentik_admin")
         Event.objects.create(
-            action=EventAction.UPDATE_AVAILABLE, context={"new_version": "99999999.9999999.9999999"}
+            action=EventAction.UPDATE_AVAILABLE,
+            context={"new_version": "99999999.9999999.9999999"},
         )
         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={"new_version": "1.1.1"})
         Event.objects.create(action=EventAction.UPDATE_AVAILABLE, context={})
```
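The tests above stub the version endpoint with `requests_mock`. The same pattern, self-contained outside authentik's test suite:

```python
import requests
from requests_mock import Mocker

with Mocker() as mocker:
    # Any GET to this URL is answered locally; no network traffic happens.
    mocker.get(
        "https://version.goauthentik.io/version.json",
        json={"stable": {"version": "2025.6.3"}},
    )
    data = requests.get("https://version.goauthentik.io/version.json").json()
    assert data["stable"]["version"] == "2025.6.3"
```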
```diff
@@ -6,13 +6,11 @@ from authentik.admin.api.meta import AppsViewSet, ModelViewSet
 from authentik.admin.api.system import SystemView
 from authentik.admin.api.version import VersionView
 from authentik.admin.api.version_history import VersionHistoryViewSet
-from authentik.admin.api.workers import WorkerView
 
 api_urlpatterns = [
     ("admin/apps", AppsViewSet, "apps"),
     ("admin/models", ModelViewSet, "models"),
     path("admin/version/", VersionView.as_view(), name="admin_version"),
     ("admin/version/history", VersionHistoryViewSet, "version_history"),
-    path("admin/workers/", WorkerView.as_view(), name="admin_workers"),
     path("admin/system/", SystemView.as_view(), name="admin_system"),
 ]
```
```diff
@@ -39,7 +39,7 @@ class BlueprintInstanceSerializer(ModelSerializer):
         """Ensure the path (if set) specified is retrievable"""
         if path == "" or path.startswith(OCI_PREFIX):
             return path
-        files: list[dict] = blueprints_find_dict.delay().get()
+        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
         if path not in [file["path"] for file in files]:
             raise ValidationError(_("Blueprint file does not exist"))
         return path
@@ -115,7 +115,7 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
     @action(detail=False, pagination_class=None, filter_backends=[])
     def available(self, request: Request) -> Response:
         """Get blueprints"""
-        files: list[dict] = blueprints_find_dict.delay().get()
+        files: list[dict] = blueprints_find_dict.send().get_result(block=True)
         return Response(files)
 
     @permission_required("authentik_blueprints.view_blueprintinstance")
@@ -129,5 +129,5 @@ class BlueprintInstanceViewSet(UsedByMixin, ModelViewSet):
     def apply(self, request: Request, *args, **kwargs) -> Response:
         """Apply a blueprint"""
         blueprint = self.get_object()
-        apply_blueprint.delay(str(blueprint.pk)).get()
+        apply_blueprint.send_with_options(args=(blueprint.pk,), rel_obj=blueprint)
         return self.retrieve(request, *args, **kwargs)
```
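Blocking call sites move from Celery's `AsyncResult.get()` to dramatiq's result backend via `.send().get_result(block=True)`. A hedged, self-contained sketch of that pattern (authentik wires its backend up through django_dramatiq_postgres; the stub backend and in-process worker below are demo stand-ins):

```python
import dramatiq
from dramatiq import Worker
from dramatiq.brokers.stub import StubBroker
from dramatiq.results import Results
from dramatiq.results.backends.stub import StubBackend

# get_result() requires a Results middleware with a backend.
broker = StubBroker()
broker.add_middleware(Results(backend=StubBackend()))
dramatiq.set_broker(broker)


@dramatiq.actor(store_results=True)  # results must be stored to be fetched
def add(x: int, y: int) -> int:
    return x + y


worker = Worker(broker, worker_timeout=100)  # in-process worker for the demo
worker.start()
try:
    message = add.send(2, 3)
    # block=True polls the backend until the worker stores a value.
    assert message.get_result(block=True) == 5
finally:
    worker.stop()
```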
```diff
@@ -6,9 +6,12 @@ from inspect import ismethod
 
 from django.apps import AppConfig
 from django.db import DatabaseError, InternalError, ProgrammingError
+from dramatiq.broker import get_broker
 from structlog.stdlib import BoundLogger, get_logger
 
+from authentik.lib.utils.time import fqdn_rand
 from authentik.root.signals import startup
+from authentik.tasks.schedules.lib import ScheduleSpec
 
 
 class ManagedAppConfig(AppConfig):
@@ -34,7 +37,7 @@ class ManagedAppConfig(AppConfig):
 
     def import_related(self):
         """Automatically import related modules which rely on just being imported
-        to register themselves (mainly django signals and celery tasks)"""
+        to register themselves (mainly django signals and tasks)"""
 
         def import_relative(rel_module: str):
             try:
@@ -80,6 +83,16 @@
         func._authentik_managed_reconcile = ManagedAppConfig.RECONCILE_GLOBAL_CATEGORY
         return func
 
+    @property
+    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
+        """Get a list of schedule specs that must exist in each tenant"""
+        return []
+
+    @property
+    def global_schedule_specs(self) -> list[ScheduleSpec]:
+        """Get a list of schedule specs that must exist in the default tenant"""
+        return []
+
     def _reconcile_tenant(self) -> None:
         """reconcile ourselves for tenanted methods"""
         from authentik.tenants.models import Tenant
@@ -100,8 +113,12 @@
         """
         from django_tenants.utils import get_public_schema_name, schema_context
 
-        with schema_context(get_public_schema_name()):
-            self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
+        try:
+            with schema_context(get_public_schema_name()):
+                self._reconcile(self.RECONCILE_GLOBAL_CATEGORY)
+        except (DatabaseError, ProgrammingError, InternalError) as exc:
+            self.logger.debug("Failed to access database to run reconcile", exc=exc)
+            return
 
 
 class AuthentikBlueprintsConfig(ManagedAppConfig):
@@ -112,19 +129,29 @@ class AuthentikBlueprintsConfig(ManagedAppConfig):
     verbose_name = "authentik Blueprints"
     default = True
 
-    @ManagedAppConfig.reconcile_global
-    def load_blueprints_v1_tasks(self):
-        """Load v1 tasks"""
-        self.import_module("authentik.blueprints.v1.tasks")
-
-    @ManagedAppConfig.reconcile_tenant
-    def blueprints_discovery(self):
-        """Run blueprint discovery"""
-        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
-
-        blueprints_discovery.delay()
-        clear_failed_blueprints.delay()
-
     def import_models(self):
         super().import_models()
         self.import_module("authentik.blueprints.v1.meta.apply_blueprint")
+
+    @ManagedAppConfig.reconcile_global
+    def tasks_middlewares(self):
+        from authentik.blueprints.v1.tasks import BlueprintWatcherMiddleware
+
+        get_broker().add_middleware(BlueprintWatcherMiddleware())
+
+    @property
+    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
+        from authentik.blueprints.v1.tasks import blueprints_discovery, clear_failed_blueprints
+
+        return [
+            ScheduleSpec(
+                actor=blueprints_discovery,
+                crontab=f"{fqdn_rand('blueprints_v1_discover')} * * * *",
+                send_on_startup=True,
+            ),
+            ScheduleSpec(
+                actor=clear_failed_blueprints,
+                crontab=f"{fqdn_rand('blueprints_v1_cleanup')} * * * *",
+                send_on_startup=True,
+            ),
+        ]
```
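`BlueprintWatcherMiddleware` (defined later in this diff) replaces the old Celery worker-ready hook with dramatiq's worker-lifecycle middleware. A minimal sketch of that middleware pattern, independent of authentik:

```python
import dramatiq
from dramatiq.brokers.stub import StubBroker
from dramatiq.middleware import Middleware


class StartupMiddleware(Middleware):
    """Run one-time setup once a worker process has booted."""

    def after_worker_boot(self, broker, worker):
        # dramatiq calls this hook after each worker boots; the real
        # BlueprintWatcherMiddleware starts the blueprint file watcher here.
        print("worker is up, starting background watchers")


broker = StubBroker()  # in-memory broker, fine for a demo
broker.add_middleware(StartupMiddleware())
dramatiq.set_broker(broker)
```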
```diff
@@ -3,6 +3,7 @@
 from pathlib import Path
 from uuid import uuid4
 
+from django.contrib.contenttypes.fields import GenericRelation
 from django.contrib.postgres.fields import ArrayField
 from django.db import models
 from django.utils.translation import gettext_lazy as _
@@ -71,6 +72,13 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel):
     enabled = models.BooleanField(default=True)
     managed_models = ArrayField(models.TextField(), default=list)
 
+    # Manual link to tasks instead of using TasksModel because of loop imports
+    tasks = GenericRelation(
+        "authentik_tasks.Task",
+        content_type_field="rel_obj_content_type",
+        object_id_field="rel_obj_id",
+    )
+
     class Meta:
         verbose_name = _("Blueprint Instance")
         verbose_name_plural = _("Blueprint Instances")
```
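With the reverse relation in place, tasks recorded against a blueprint are reachable from the instance itself. A small usage sketch (model names as in the diff; the queried data is hypothetical):

```python
from authentik.blueprints.models import BlueprintInstance

instance = BlueprintInstance.objects.first()
if instance is not None:
    # GenericRelation exposes a normal reverse manager over Task rows whose
    # rel_obj content-type/id columns point at this blueprint.
    for task in instance.tasks.all():
        print(task)
```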
```diff
@@ -1,18 +0,0 @@
-"""blueprint Settings"""
-
-from celery.schedules import crontab
-
-from authentik.lib.utils.time import fqdn_rand
-
-CELERY_BEAT_SCHEDULE = {
-    "blueprints_v1_discover": {
-        "task": "authentik.blueprints.v1.tasks.blueprints_discovery",
-        "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
-        "options": {"queue": "authentik_scheduled"},
-    },
-    "blueprints_v1_cleanup": {
-        "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
-        "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
-        "options": {"queue": "authentik_scheduled"},
-    },
-}
```
authentik/blueprints/tasks.py (new file; 2 lines)
```diff
@@ -0,0 +1,2 @@
+# Import all v1 tasks for auto task discovery
+from authentik.blueprints.v1.tasks import *  # noqa: F403
```
```diff
@@ -5,7 +5,6 @@ from collections.abc import Callable
 from django.apps import apps
 from django.test import TestCase
 
-from authentik.blueprints.v1.importer import is_model_allowed
 from authentik.lib.models import SerializerModel
 from authentik.providers.oauth2.models import RefreshToken
 
@@ -22,10 +21,13 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable:
             return
         model_class = test_model()
         self.assertTrue(isinstance(model_class, SerializerModel))
+        # Models that have subclasses don't have to have a serializer
+        if len(test_model.__subclasses__()) > 0:
+            return
         self.assertIsNotNone(model_class.serializer)
         if model_class.serializer.Meta().model == RefreshToken:
             return
-        self.assertEqual(model_class.serializer.Meta().model, test_model)
+        self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model))
 
     return tester
 
@@ -34,6 +36,6 @@ for app in apps.get_app_configs():
     if not app.label.startswith("authentik"):
         continue
     for model in app.get_models():
-        if not is_model_allowed(model):
+        if not issubclass(model, SerializerModel):
             continue
         setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))
```
```diff
@@ -54,7 +54,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
             file.seek(0)
             file_hash = sha512(file.read().encode()).hexdigest()
             file.flush()
-            blueprints_discovery()
+            blueprints_discovery.send()
             instance = BlueprintInstance.objects.filter(name=blueprint_id).first()
             self.assertEqual(instance.last_applied_hash, file_hash)
             self.assertEqual(
@@ -82,7 +82,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
                 )
             )
             file.flush()
-            blueprints_discovery()
+            blueprints_discovery.send()
             blueprint = BlueprintInstance.objects.filter(name="foo").first()
             self.assertEqual(
                 blueprint.last_applied_hash,
@@ -107,7 +107,7 @@ class TestBlueprintsV1Tasks(TransactionTestCase):
                 )
             )
             file.flush()
-            blueprints_discovery()
+            blueprints_discovery.send()
             blueprint.refresh_from_db()
             self.assertEqual(
                 blueprint.last_applied_hash,
```
```diff
@@ -57,7 +57,6 @@ from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import (
     EndpointDeviceConnection,
 )
 from authentik.events.logs import LogEvent, capture_logs
-from authentik.events.models import SystemTask
 from authentik.events.utils import cleanse_dict
 from authentik.flows.models import FlowToken, Stage
 from authentik.lib.models import SerializerModel
@@ -77,6 +76,7 @@ from authentik.providers.scim.models import SCIMProviderGroup, SCIMProviderUser
 from authentik.rbac.models import Role
 from authentik.sources.scim.models import SCIMSourceGroup, SCIMSourceUser
 from authentik.stages.authenticator_webauthn.models import WebAuthnDeviceType
+from authentik.tasks.models import Task
 from authentik.tenants.models import Tenant
 
 # Context set when the serializer is created in a blueprint context
@@ -118,7 +118,7 @@ def excluded_models() -> list[type[Model]]:
         SCIMProviderGroup,
         SCIMProviderUser,
         Tenant,
-        SystemTask,
+        Task,
         ConnectionToken,
         AuthorizationCode,
         AccessToken,
```
```diff
@@ -44,7 +44,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
             return MetaResult()
         LOGGER.debug("Applying blueprint from meta model", blueprint=self.blueprint_instance)
 
-        apply_blueprint(str(self.blueprint_instance.pk))
+        apply_blueprint(self.blueprint_instance.pk)
         return MetaResult()
```
```diff
@@ -4,12 +4,17 @@ from dataclasses import asdict, dataclass, field
 from hashlib import sha512
 from pathlib import Path
 from sys import platform
+from uuid import UUID
 
 from dacite.core import from_dict
+from django.conf import settings
 from django.db import DatabaseError, InternalError, ProgrammingError
 from django.utils.text import slugify
 from django.utils.timezone import now
+from django.utils.translation import gettext_lazy as _
+from django_dramatiq_postgres.middleware import CurrentTask, CurrentTaskNotFound
+from dramatiq.actor import actor
+from dramatiq.middleware import Middleware
 from structlog.stdlib import get_logger
 from watchdog.events import (
     FileCreatedEvent,
@@ -31,15 +36,13 @@ from authentik.blueprints.v1.importer import Importer
 from authentik.blueprints.v1.labels import LABEL_AUTHENTIK_INSTANTIATE
 from authentik.blueprints.v1.oci import OCI_PREFIX
 from authentik.events.logs import capture_logs
-from authentik.events.models import TaskStatus
-from authentik.events.system_tasks import SystemTask, prefill_task
 from authentik.events.utils import sanitize_dict
 from authentik.lib.config import CONFIG
-from authentik.root.celery import CELERY_APP
+from authentik.tasks.models import Task
+from authentik.tasks.schedules.models import Schedule
 from authentik.tenants.models import Tenant
 
 LOGGER = get_logger()
-_file_watcher_started = False
 
 
 @dataclass
@@ -53,22 +56,21 @@ class BlueprintFile:
     meta: BlueprintMetadata | None = field(default=None)
 
 
-def start_blueprint_watcher():
-    """Start blueprint watcher, if it's not running already."""
-    # This function might be called twice since it's called on celery startup
-
-    global _file_watcher_started  # noqa: PLW0603
-    if _file_watcher_started:
-        return
-    observer = Observer()
-    kwargs = {}
-    if platform.startswith("linux"):
-        kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
-    observer.schedule(
-        BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
-    )
-    observer.start()
-    _file_watcher_started = True
+class BlueprintWatcherMiddleware(Middleware):
+    def start_blueprint_watcher(self):
+        """Start blueprint watcher"""
+        observer = Observer()
+        kwargs = {}
+        if platform.startswith("linux"):
+            kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent)
+        observer.schedule(
+            BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs
+        )
+        observer.start()
+
+    def after_worker_boot(self, broker, worker):
+        if not settings.TEST:
+            self.start_blueprint_watcher()
 
 
 class BlueprintEventHandler(FileSystemEventHandler):
@@ -92,7 +94,7 @@ class BlueprintEventHandler(FileSystemEventHandler):
         LOGGER.debug("new blueprint file created, starting discovery")
         for tenant in Tenant.objects.filter(ready=True):
             with tenant:
-                blueprints_discovery.delay()
+                Schedule.dispatch_by_actor(blueprints_discovery)
 
     def on_modified(self, event: FileSystemEvent):
         """Process file modification"""
@@ -103,14 +105,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
             with tenant:
                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True):
                     LOGGER.debug("modified blueprint file, starting apply", instance=instance)
-                    apply_blueprint.delay(instance.pk.hex)
+                    apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
 
 
-@CELERY_APP.task(
+@actor(
+    description=_("Find blueprints as `blueprints_find` does, but return a safe dict."),
     throws=(DatabaseError, ProgrammingError, InternalError),
 )
 def blueprints_find_dict():
-    """Find blueprints as `blueprints_find` does, but return a safe dict"""
     blueprints = []
     for blueprint in blueprints_find():
         blueprints.append(sanitize_dict(asdict(blueprint)))
@@ -146,21 +148,19 @@ def blueprints_find() -> list[BlueprintFile]:
     return blueprints
 
 
-@CELERY_APP.task(
-    throws=(DatabaseError, ProgrammingError, InternalError), base=SystemTask, bind=True
+@actor(
+    description=_("Find blueprints and check if they need to be created in the database."),
+    throws=(DatabaseError, ProgrammingError, InternalError),
 )
-@prefill_task
-def blueprints_discovery(self: SystemTask, path: str | None = None):
-    """Find blueprints and check if they need to be created in the database"""
+def blueprints_discovery(path: str | None = None):
+    self: Task = CurrentTask.get_task()
     count = 0
     for blueprint in blueprints_find():
        if path and blueprint.path != path:
            continue
        check_blueprint_v1_file(blueprint)
        count += 1
-    self.set_status(
-        TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=count))
-    )
+    self.info(f"Successfully imported {count} files.")
 
 
 def check_blueprint_v1_file(blueprint: BlueprintFile):
@@ -187,22 +187,26 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
         )
     if instance.last_applied_hash != blueprint.hash:
         LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
-        apply_blueprint.delay(str(instance.pk))
+        apply_blueprint.send_with_options(args=(instance.pk,), rel_obj=instance)
 
 
-@CELERY_APP.task(
-    bind=True,
-    base=SystemTask,
-)
-def apply_blueprint(self: SystemTask, instance_pk: str):
-    """Apply single blueprint"""
-    self.save_on_success = False
+@actor(description=_("Apply single blueprint."))
+def apply_blueprint(instance_pk: UUID):
+    try:
+        self: Task = CurrentTask.get_task()
+    except CurrentTaskNotFound:
+        self = Task()
+    self.set_uid(str(instance_pk))
     instance: BlueprintInstance | None = None
     try:
         instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
-        if not instance or not instance.enabled:
+        if not instance:
+            self.warning(f"Could not find blueprint {instance_pk}, skipping")
             return
         self.set_uid(slugify(instance.name))
+        if not instance.enabled:
+            self.info(f"Blueprint {instance.name} is disabled, skipping")
+            return
         blueprint_content = instance.retrieve()
         file_hash = sha512(blueprint_content.encode()).hexdigest()
         importer = Importer.from_string(blueprint_content, instance.context)
@@ -212,19 +216,18 @@ def apply_blueprint(self: SystemTask, instance_pk: str):
         if not valid:
             instance.status = BlueprintInstanceStatus.ERROR
             instance.save()
-            self.set_status(TaskStatus.ERROR, *logs)
+            self.logs(logs)
             return
         with capture_logs() as logs:
             applied = importer.apply()
             if not applied:
                 instance.status = BlueprintInstanceStatus.ERROR
                 instance.save()
-                self.set_status(TaskStatus.ERROR, *logs)
+                self.logs(logs)
                 return
         instance.status = BlueprintInstanceStatus.SUCCESSFUL
         instance.last_applied_hash = file_hash
         instance.last_applied = now()
-        self.set_status(TaskStatus.SUCCESSFUL)
     except (
         OSError,
         DatabaseError,
@@ -235,15 +238,14 @@ def apply_blueprint(self: SystemTask, instance_pk: str):
     ) as exc:
         if instance:
             instance.status = BlueprintInstanceStatus.ERROR
-        self.set_error(exc)
+        self.error(exc)
     finally:
         if instance:
             instance.save()
 
 
-@CELERY_APP.task()
+@actor(description=_("Remove blueprints which couldn't be fetched."))
 def clear_failed_blueprints():
-    """Remove blueprints which couldn't be fetched"""
     # Exclude OCI blueprints as those might be temporarily unavailable
     for blueprint in BlueprintInstance.objects.exclude(path__startswith=OCI_PREFIX):
         try:
```
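Two dramatiq idioms recur in this file: `throws=(...)`, which marks the listed exceptions as expected so the worker skips retries and noisy logging, and `send_with_options(...)`, which carries explicit args plus middleware-interpreted extras such as `rel_obj` (consumed by django_dramatiq_postgres, not dramatiq core). A compact sketch:

```python
import dramatiq
from dramatiq.brokers.stub import StubBroker

dramatiq.set_broker(StubBroker())


class TransientStartupError(Exception):
    pass


@dramatiq.actor(throws=(TransientStartupError,))
def probe(name: str):
    # Exceptions listed in `throws` are treated as expected: the Retries
    # middleware does not retry them, and they are logged less loudly.
    raise TransientStartupError(name)


# args travel explicitly; delay= postpones delivery by milliseconds. Extra
# options (like authentik's rel_obj) are left for middleware to interpret.
probe.send_with_options(args=("database not ready",), delay=1000)
```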
```diff
@@ -9,6 +9,7 @@ class AuthentikBrandsConfig(ManagedAppConfig):
     name = "authentik.brands"
     label = "authentik_brands"
     verbose_name = "authentik Brands"
+    default = True
     mountpoints = {
         "authentik.brands.urls_root": "",
     }
```
```diff
@@ -1,8 +1,7 @@
 """authentik core app config"""
 
-from django.conf import settings
-
 from authentik.blueprints.apps import ManagedAppConfig
+from authentik.tasks.schedules.lib import ScheduleSpec
 
 
 class AuthentikCoreConfig(ManagedAppConfig):
@@ -14,14 +13,6 @@ class AuthentikCoreConfig(ManagedAppConfig):
     mountpoint = ""
     default = True
 
-    @ManagedAppConfig.reconcile_global
-    def debug_worker_hook(self):
-        """Dispatch startup tasks inline when debugging"""
-        if settings.DEBUG:
-            from authentik.root.celery import worker_ready_hook
-
-            worker_ready_hook()
-
     @ManagedAppConfig.reconcile_tenant
     def source_inbuilt(self):
         """Reconcile inbuilt source"""
@@ -34,3 +25,18 @@ class AuthentikCoreConfig(ManagedAppConfig):
             },
             managed=Source.MANAGED_INBUILT,
         )
+
+    @property
+    def tenant_schedule_specs(self) -> list[ScheduleSpec]:
+        from authentik.core.tasks import clean_expired_models, clean_temporary_users
+
+        return [
+            ScheduleSpec(
+                actor=clean_expired_models,
+                crontab="2-59/5 * * * *",
+            ),
+            ScheduleSpec(
+                actor=clean_temporary_users,
+                crontab="9-59/5 * * * *",
+            ),
+        ]
```
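The two crontabs deliberately stagger the jobs: `2-59/5 * * * *` fires at minutes 2, 7, 12, … and `9-59/5 * * * *` at 9, 14, 19, …. A quick check with the third-party croniter package (an illustration; croniter is not part of this change):

```python
from datetime import datetime

from croniter import croniter  # pip install croniter

start = datetime(2025, 1, 1, 12, 0)
expired = croniter("2-59/5 * * * *", start)
temp_users = croniter("9-59/5 * * * *", start)

print([expired.get_next(datetime).minute for _ in range(3)])     # [2, 7, 12]
print([temp_users.get_next(datetime).minute for _ in range(3)])  # [9, 14, 19]
```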
```diff
@@ -1,21 +0,0 @@
-"""Run bootstrap tasks"""
-
-from django.core.management.base import BaseCommand
-from django_tenants.utils import get_public_schema_name
-
-from authentik.root.celery import _get_startup_tasks_all_tenants, _get_startup_tasks_default_tenant
-from authentik.tenants.models import Tenant
-
-
-class Command(BaseCommand):
-    """Run bootstrap tasks to ensure certain objects are created"""
-
-    def handle(self, **options):
-        for task in _get_startup_tasks_default_tenant():
-            with Tenant.objects.get(schema_name=get_public_schema_name()):
-                task()
-
-        for task in _get_startup_tasks_all_tenants():
-            for tenant in Tenant.objects.filter(ready=True):
-                with tenant:
-                    task()
```
| """Run worker""" | ||||
|  | ||||
| from sys import exit as sysexit | ||||
| from tempfile import tempdir | ||||
|  | ||||
| from celery.apps.worker import Worker | ||||
| from django.core.management.base import BaseCommand | ||||
| from django.db import close_old_connections | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.debug import start_debug_server | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class Command(BaseCommand): | ||||
|     """Run worker""" | ||||
|  | ||||
|     def add_arguments(self, parser): | ||||
|         parser.add_argument( | ||||
|             "-b", | ||||
|             "--beat", | ||||
|             action="store_false", | ||||
|             help="When set, this worker will _not_ run Beat (scheduled) tasks", | ||||
|         ) | ||||
|  | ||||
|     def handle(self, **options): | ||||
|         LOGGER.debug("Celery options", **options) | ||||
|         close_old_connections() | ||||
|         start_debug_server() | ||||
|         worker: Worker = CELERY_APP.Worker( | ||||
|             no_color=False, | ||||
|             quiet=True, | ||||
|             optimization="fair", | ||||
|             autoscale=(CONFIG.get_int("worker.concurrency"), 1), | ||||
|             task_events=True, | ||||
|             beat=options.get("beat", True), | ||||
|             schedule_filename=f"{tempdir}/celerybeat-schedule", | ||||
|             queues=["authentik", "authentik_scheduled", "authentik_events"], | ||||
|         ) | ||||
|         for task in CELERY_APP.tasks: | ||||
|             LOGGER.debug("Registered task", task=task) | ||||
|  | ||||
|         worker.start() | ||||
|         sysexit(worker.exitcode) | ||||
| @ -1082,6 +1082,12 @@ class AuthenticatedSession(SerializerModel): | ||||
|  | ||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer | ||||
|  | ||||
|         return AuthenticatedSessionSerializer | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Authenticated Session") | ||||
|         verbose_name_plural = _("Authenticated Sessions") | ||||
|  | ||||
| @ -3,6 +3,9 @@ | ||||
| from datetime import datetime, timedelta | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import actor | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import ( | ||||
| @ -11,17 +14,14 @@ from authentik.core.models import ( | ||||
|     ExpiringModel, | ||||
|     User, | ||||
| ) | ||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def clean_expired_models(self: SystemTask): | ||||
|     """Remove expired objects""" | ||||
|     messages = [] | ||||
| @actor(description=_("Remove expired objects.")) | ||||
| def clean_expired_models(): | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     for cls in ExpiringModel.__subclasses__(): | ||||
|         cls: ExpiringModel | ||||
|         objects = ( | ||||
| @ -31,16 +31,13 @@ def clean_expired_models(self: SystemTask): | ||||
|         for obj in objects: | ||||
|             obj.expire_action() | ||||
|         LOGGER.debug("Expired models", model=cls, amount=amount) | ||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|         self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def clean_temporary_users(self: SystemTask): | ||||
|     """Remove temporary users created by SAML Sources""" | ||||
| @actor(description=_("Remove temporary users created by SAML Sources.")) | ||||
| def clean_temporary_users(): | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     _now = datetime.now() | ||||
|     messages = [] | ||||
|     deleted_users = 0 | ||||
|     for user in User.objects.filter(**{f"attributes__{USER_ATTRIBUTE_GENERATED}": True}): | ||||
|         if not user.attributes.get(USER_ATTRIBUTE_EXPIRES): | ||||
| @ -52,5 +49,4 @@ def clean_temporary_users(self: SystemTask): | ||||
|             LOGGER.debug("User is expired and will be deleted.", user=user, delta=delta) | ||||
|             user.delete() | ||||
|             deleted_users += 1 | ||||
|     messages.append(f"Successfully deleted {deleted_users} users.") | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|     self.info(f"Successfully deleted {deleted_users} users.") | ||||
|  | ||||
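This hunk is the template for the whole conversion: the bound Celery task (`bind=True, base=SystemTask`) becomes a plain function under `@actor`, and the persistent task record is pulled from middleware rather than passed in as `self`. Reduced to its skeleton (assuming, as the hunk suggests, that `Task.info` appends a message to the stored task log):

```python
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor

from authentik.tasks.models import Task


@actor(description=_("Example converted task."))
def example_task():
    # The middleware exposes the Task row for the running message;
    # keeping the name `self` preserves the old method-style body.
    self: Task = CurrentTask.get_task()
    self.info("progress messages accumulate on the Task model")
```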
| @ -36,7 +36,7 @@ class TestTasks(APITestCase): | ||||
|             expires=now(), user=get_anonymous_user(), intent=TokenIntents.INTENT_API | ||||
|         ) | ||||
|         key = token.key | ||||
|         clean_expired_models.delay().get() | ||||
|         clean_expired_models.send() | ||||
|         token.refresh_from_db() | ||||
|         self.assertNotEqual(key, token.key) | ||||
|  | ||||
| @ -50,5 +50,5 @@ class TestTasks(APITestCase): | ||||
|                 USER_ATTRIBUTE_EXPIRES: mktime(now().timetuple()), | ||||
|             }, | ||||
|         ) | ||||
|         clean_temporary_users.delay().get() | ||||
|         clean_temporary_users.send() | ||||
|         self.assertFalse(User.objects.filter(username=username)) | ||||
|  | ||||
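The test changes follow a mechanical mapping between the two enqueue APIs. Roughly (hedged: `.get_result()` only works when dramatiq's Results middleware and a result backend are configured, which the provider sync tests later in this diff rely on):

```python
# Celery                          ->  dramatiq (as used in this branch)
# task.delay(*args)               ->  task.send(*args)
# task.delay(*args).get()         ->  task.send(*args).get_result()
# task.apply_async(args, **opts)  ->  task.send_with_options(args=args, **opts)
```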
| @ -4,6 +4,8 @@ from datetime import UTC, datetime | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
|  | ||||
| MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" | ||||
|  | ||||
| @ -67,3 +69,14 @@ class AuthentikCryptoConfig(ManagedAppConfig): | ||||
|                 "key_data": builder.private_key, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.crypto.tasks import certificate_discovery | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=certificate_discovery, | ||||
|                 crontab=f"{fqdn_rand('crypto_certificate_discovery')} * * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
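`fqdn_rand` seeds the cron minute from the host name so that a fleet of installations spreads its scheduled load instead of all firing at minute 0, while any single installation keeps a stable slot. A self-contained sketch of the idea (hash choice and range are assumptions; the real helper lives in `authentik.lib.utils.time`):

```python
from hashlib import sha256
from socket import getfqdn


def fqdn_rand(seed: str, stop: int = 60) -> int:
    """Deterministic value in [0, stop) derived from the FQDN plus a seed."""
    digest = sha256(f"{getfqdn()}/{seed}".encode()).digest()
    return int.from_bytes(digest[:4], "big") % stop


# e.g. "37 * * * *" on one host, "12 * * * *" on another, stable per host
print(f"{fqdn_rand('crypto_certificate_discovery')} * * * *")
```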
| @ -1,13 +0,0 @@ | ||||
| """Crypto task Settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "crypto_certificate_discovery": { | ||||
|         "task": "authentik.crypto.tasks.certificate_discovery", | ||||
|         "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
| @ -7,13 +7,13 @@ from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives.serialization import load_pem_private_key | ||||
| from cryptography.x509.base import load_pem_x509_certificate | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import actor | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -36,10 +36,9 @@ def ensure_certificate_valid(body: str): | ||||
|     return body | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def certificate_discovery(self: SystemTask): | ||||
|     """Discover, import and update certificates from the filesystem""" | ||||
| @actor(description=_("Discover, import and update certificates from the filesystem.")) | ||||
| def certificate_discovery(): | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     certs = {} | ||||
|     private_keys = {} | ||||
|     discovered = 0 | ||||
| @ -84,6 +83,4 @@ def certificate_discovery(self: SystemTask): | ||||
|                 dirty = True | ||||
|         if dirty: | ||||
|             cert.save() | ||||
|     self.set_status( | ||||
|         TaskStatus.SUCCESSFUL, _("Successfully imported {count} files.".format(count=discovered)) | ||||
|     ) | ||||
|     self.info(f"Successfully imported {discovered} files.") | ||||
|  | ||||
| @ -338,7 +338,7 @@ class TestCrypto(APITestCase): | ||||
|             with open(f"{temp_dir}/foo.bar/privkey.pem", "w+", encoding="utf-8") as _key: | ||||
|                 _key.write(builder.private_key) | ||||
|             with CONFIG.patch("cert_discovery_dir", temp_dir): | ||||
|                 certificate_discovery() | ||||
|                 certificate_discovery.send() | ||||
|         keypair: CertificateKeyPair = CertificateKeyPair.objects.filter( | ||||
|             managed=MANAGED_DISCOVERED % "foo" | ||||
|         ).first() | ||||
|  | ||||
| @ -3,6 +3,8 @@ | ||||
| from django.conf import settings | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
|  | ||||
|  | ||||
| class EnterpriseConfig(ManagedAppConfig): | ||||
| @ -26,3 +28,14 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | ||||
|         from authentik.enterprise.license import LicenseKey | ||||
|  | ||||
|         return LicenseKey.cached_summary().status.is_valid | ||||
|  | ||||
|     @property | ||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.enterprise.tasks import enterprise_update_usage | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=enterprise_update_usage, | ||||
|                 crontab=f"{fqdn_rand('enterprise_update_usage')} */2 * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| """authentik Unique Password policy app config""" | ||||
|  | ||||
| from authentik.enterprise.apps import EnterpriseConfig | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
|  | ||||
|  | ||||
| class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | ||||
| @ -8,3 +10,21 @@ class AuthentikEnterprisePoliciesUniquePasswordConfig(EnterpriseConfig): | ||||
|     label = "authentik_policies_unique_password" | ||||
|     verbose_name = "authentik Enterprise.Policies.Unique Password" | ||||
|     default = True | ||||
|  | ||||
|     @property | ||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.enterprise.policies.unique_password.tasks import ( | ||||
|             check_and_purge_password_history, | ||||
|             trim_password_histories, | ||||
|         ) | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=trim_password_histories, | ||||
|                 crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *", | ||||
|             ), | ||||
|             ScheduleSpec( | ||||
|                 actor=check_and_purge_password_history, | ||||
|                 crontab=f"{fqdn_rand('policies_unique_password_purge')} */24 * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
| @ -1,20 +0,0 @@ | ||||
| """Unique Password Policy settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "policies_unique_password_trim_history": { | ||||
|         "task": "authentik.enterprise.policies.unique_password.tasks.trim_password_histories", | ||||
|         "schedule": crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
|     "policies_unique_password_check_purge": { | ||||
|         "task": ( | ||||
|             "authentik.enterprise.policies.unique_password.tasks.check_and_purge_password_history" | ||||
|         ), | ||||
|         "schedule": crontab(minute=fqdn_rand("policies_unique_password_purge"), hour="*/24"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
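Each deleted `CELERY_BEAT_SCHEDULE` entry maps one-to-one onto a `ScheduleSpec` in the owning app config (see the apps.py hunk just above); the explicit `authentik_scheduled` queue routing disappears, presumably because the dramatiq scheduler owns queue placement. The trim entry, expressed both ways:

```python
from authentik.enterprise.policies.unique_password.tasks import trim_password_histories
from authentik.lib.utils.time import fqdn_rand
from authentik.tasks.schedules.lib import ScheduleSpec

# celery: crontab(minute=fqdn_rand("policies_unique_password_trim"), hour="*/12")
spec = ScheduleSpec(
    actor=trim_password_histories,
    crontab=f"{fqdn_rand('policies_unique_password_trim')} */12 * * *",
)
```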
| @ -1,35 +1,37 @@ | ||||
| from django.db.models.aggregates import Count | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import actor | ||||
| from structlog import get_logger | ||||
|  | ||||
| from authentik.enterprise.policies.unique_password.models import ( | ||||
|     UniquePasswordPolicy, | ||||
|     UserPasswordHistory, | ||||
| ) | ||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def check_and_purge_password_history(self: SystemTask): | ||||
|     """Check if any UniquePasswordPolicy exists, and if not, purge the password history table. | ||||
|     This is run on a schedule instead of being triggered by policy binding deletion. | ||||
|     """ | ||||
| @actor( | ||||
|     description=_( | ||||
|         "Check if any UniquePasswordPolicy exists, and if not, purge the password history table." | ||||
|     ) | ||||
| ) | ||||
| def check_and_purge_password_history(): | ||||
|     self: Task = CurrentTask.get_task() | ||||
|  | ||||
|     if not UniquePasswordPolicy.objects.exists(): | ||||
|         UserPasswordHistory.objects.all().delete() | ||||
|         LOGGER.debug("Purged UserPasswordHistory table as no policies are in use") | ||||
|         self.set_status(TaskStatus.SUCCESSFUL, "Successfully purged UserPasswordHistory") | ||||
|         self.info("Successfully purged UserPasswordHistory") | ||||
|         return | ||||
|  | ||||
|     self.set_status( | ||||
|         TaskStatus.SUCCESSFUL, "Not purging password histories, a unique password policy exists" | ||||
|     ) | ||||
|     self.info("Not purging password histories, a unique password policy exists") | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| def trim_password_histories(self: SystemTask): | ||||
| @actor(description=_("Remove user password history that are too old.")) | ||||
| def trim_password_histories(): | ||||
|     """Removes rows from UserPasswordHistory older than | ||||
|     the `n` most recent entries. | ||||
|  | ||||
| @ -37,6 +39,8 @@ def trim_password_histories(self: SystemTask): | ||||
|     UniquePasswordPolicy policies. | ||||
|     """ | ||||
|  | ||||
|     self: Task = CurrentTask.get_task() | ||||
|  | ||||
|     # No policy, we'll let the cleanup above do its thing | ||||
|     if not UniquePasswordPolicy.objects.exists(): | ||||
|         return | ||||
| @ -63,4 +67,4 @@ def trim_password_histories(self: SystemTask): | ||||
|  | ||||
|     num_deleted, _ = UserPasswordHistory.objects.exclude(pk__in=all_pks_to_keep).delete() | ||||
|     LOGGER.debug("Deleted stale password history records", count=num_deleted) | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, f"Delete {num_deleted} stale password history records") | ||||
|     self.info(f"Delete {num_deleted} stale password history records") | ||||
|  | ||||
| @ -76,7 +76,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | ||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||
|  | ||||
|         # Run the task - should purge since no policy is in use | ||||
|         check_and_purge_password_history() | ||||
|         check_and_purge_password_history.send() | ||||
|  | ||||
|         # Verify the table is empty | ||||
|         self.assertFalse(UserPasswordHistory.objects.exists()) | ||||
| @ -99,7 +99,7 @@ class TestCheckAndPurgePasswordHistory(TestCase): | ||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||
|  | ||||
|         # Run the task - should NOT purge since a policy is in use | ||||
|         check_and_purge_password_history() | ||||
|         check_and_purge_password_history.send() | ||||
|  | ||||
|         # Verify the entries still exist | ||||
|         self.assertTrue(UserPasswordHistory.objects.exists()) | ||||
| @ -142,7 +142,7 @@ class TestTrimPasswordHistory(TestCase): | ||||
|             enabled=True, | ||||
|             order=0, | ||||
|         ) | ||||
|         trim_password_histories.delay() | ||||
|         trim_password_histories.send() | ||||
|         user_pwd_history_qs = UserPasswordHistory.objects.filter(user=self.user) | ||||
|         self.assertEqual(len(user_pwd_history_qs), 1) | ||||
|  | ||||
| @ -159,7 +159,7 @@ class TestTrimPasswordHistory(TestCase): | ||||
|             enabled=False, | ||||
|             order=0, | ||||
|         ) | ||||
|         trim_password_histories.delay() | ||||
|         trim_password_histories.send() | ||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||
|  | ||||
|     def test_trim_password_history_fewer_records_than_maximum_is_no_op(self): | ||||
| @ -174,5 +174,5 @@ class TestTrimPasswordHistory(TestCase): | ||||
|             enabled=True, | ||||
|             order=0, | ||||
|         ) | ||||
|         trim_password_histories.delay() | ||||
|         trim_password_histories.send() | ||||
|         self.assertTrue(UserPasswordHistory.objects.filter(user=self.user).exists()) | ||||
|  | ||||
| @ -55,5 +55,5 @@ class GoogleWorkspaceProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixi | ||||
|     ] | ||||
|     search_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
|     sync_single_task = google_workspace_sync | ||||
|     sync_task = google_workspace_sync | ||||
|     sync_objects_task = google_workspace_sync_objects | ||||
|  | ||||
| @ -7,6 +7,7 @@ from django.db import models | ||||
| from django.db.models import QuerySet | ||||
| from django.templatetags.static import static | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import Actor | ||||
| from google.oauth2.service_account import Credentials | ||||
| from rest_framework.serializers import Serializer | ||||
|  | ||||
| @ -110,6 +111,12 @@ class GoogleWorkspaceProvider(OutgoingSyncProvider, BackchannelProvider): | ||||
|         help_text=_("Property mappings used for group creation/updating."), | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def sync_actor(self) -> Actor: | ||||
|         from authentik.enterprise.providers.google_workspace.tasks import google_workspace_sync | ||||
|  | ||||
|         return google_workspace_sync | ||||
|  | ||||
|     def client_for_model( | ||||
|         self, | ||||
|         model: type[User | Group | GoogleWorkspaceProviderUser | GoogleWorkspaceProviderGroup], | ||||
|  | ||||
| @ -1,13 +0,0 @@ | ||||
| """Google workspace provider task Settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "providers_google_workspace_sync": { | ||||
|         "task": "authentik.enterprise.providers.google_workspace.tasks.google_workspace_sync_all", | ||||
|         "schedule": crontab(minute=fqdn_rand("google_workspace_sync_all"), hour="*/4"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
| @ -2,15 +2,13 @@ | ||||
|  | ||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||
| from authentik.enterprise.providers.google_workspace.tasks import ( | ||||
|     google_workspace_sync, | ||||
|     google_workspace_sync_direct, | ||||
|     google_workspace_sync_m2m, | ||||
|     google_workspace_sync_direct_dispatch, | ||||
|     google_workspace_sync_m2m_dispatch, | ||||
| ) | ||||
| from authentik.lib.sync.outgoing.signals import register_signals | ||||
|  | ||||
| register_signals( | ||||
|     GoogleWorkspaceProvider, | ||||
|     task_sync_single=google_workspace_sync, | ||||
|     task_sync_direct=google_workspace_sync_direct, | ||||
|     task_sync_m2m=google_workspace_sync_m2m, | ||||
|     task_sync_direct_dispatch=google_workspace_sync_direct_dispatch, | ||||
|     task_sync_m2m_dispatch=google_workspace_sync_m2m_dispatch, | ||||
| ) | ||||
|  | ||||
| @ -1,37 +1,48 @@ | ||||
| """Google Provider tasks""" | ||||
|  | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import actor | ||||
|  | ||||
| from authentik.enterprise.providers.google_workspace.models import GoogleWorkspaceProvider | ||||
| from authentik.events.system_tasks import SystemTask | ||||
| from authentik.lib.sync.outgoing.exceptions import TransientSyncException | ||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| sync_tasks = SyncTasks(GoogleWorkspaceProvider) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor(description=_("Sync Google Workspace provider objects.")) | ||||
| def google_workspace_sync_objects(*args, **kwargs): | ||||
|     return sync_tasks.sync_objects(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | ||||
| ) | ||||
| def google_workspace_sync(self, provider_pk: int, *args, **kwargs): | ||||
| @actor(description=_("Full sync for Google Workspace provider.")) | ||||
| def google_workspace_sync(provider_pk: int, *args, **kwargs): | ||||
|     """Run full sync for Google Workspace provider""" | ||||
|     return sync_tasks.sync_single(self, provider_pk, google_workspace_sync_objects) | ||||
|     return sync_tasks.sync(provider_pk, google_workspace_sync_objects) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def google_workspace_sync_all(): | ||||
|     return sync_tasks.sync_all(google_workspace_sync) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor(description=_("Sync a direct object (user, group) for Google Workspace provider.")) | ||||
| def google_workspace_sync_direct(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor( | ||||
|     description=_( | ||||
|         "Dispatch syncs for a direct object (user, group) for Google Workspace providers." | ||||
|     ) | ||||
| ) | ||||
| def google_workspace_sync_direct_dispatch(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_direct_dispatch(google_workspace_sync_direct, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| @actor(description=_("Sync a related object (memberships) for Google Workspace provider.")) | ||||
| def google_workspace_sync_m2m(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @actor( | ||||
|     description=_( | ||||
|         "Dispatch syncs for a related object (memberships) for Google Workspace providers." | ||||
|     ) | ||||
| ) | ||||
| def google_workspace_sync_m2m_dispatch(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_m2m_dispatch(google_workspace_sync_m2m, *args, **kwargs) | ||||
|  | ||||
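Note the shape change: Celery had one beat task (`google_workspace_sync_all`) plus signal handlers enqueueing per-object work directly, while the dramatiq version adds `*_dispatch` actors that fan work out per provider. An illustrative stand-in for that fan-out (all names hypothetical; the real logic lives in `authentik.lib.sync.outgoing.tasks.SyncTasks`, and the `description` option used in this branch comes from the project's middleware, so it is omitted here):

```python
from dramatiq.actor import actor


@actor
def example_sync_direct(provider_pk: int, model_name: str, pk: str, raw_op: str):
    """Sync a single object against a single provider."""


@actor
def example_sync_direct_dispatch(model_name: str, pk: str, raw_op: str):
    """Fan out: one message per provider, instead of a loop in one task."""
    for provider_pk in (1, 2, 3):  # stand-in for a Provider queryset
        example_sync_direct.send(provider_pk, model_name, pk, raw_op)
```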
| @ -324,7 +324,7 @@ class GoogleWorkspaceGroupTests(TestCase): | ||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||
|         ): | ||||
|             google_workspace_sync.delay(self.provider.pk).get() | ||||
|             google_workspace_sync.send(self.provider.pk).get_result() | ||||
|             self.assertTrue( | ||||
|                 GoogleWorkspaceProviderGroup.objects.filter( | ||||
|                     group=different_group, provider=self.provider | ||||
|  | ||||
| @ -302,7 +302,7 @@ class GoogleWorkspaceUserTests(TestCase): | ||||
|             "authentik.enterprise.providers.google_workspace.models.GoogleWorkspaceProvider.google_credentials", | ||||
|             MagicMock(return_value={"developerKey": self.api_key, "http": http}), | ||||
|         ): | ||||
|             google_workspace_sync.delay(self.provider.pk).get() | ||||
|             google_workspace_sync.send(self.provider.pk).get_result() | ||||
|             self.assertTrue( | ||||
|                 GoogleWorkspaceProviderUser.objects.filter( | ||||
|                     user=different_user, provider=self.provider | ||||
|  | ||||
| @ -53,5 +53,5 @@ class MicrosoftEntraProviderViewSet(OutgoingSyncProviderStatusMixin, UsedByMixin | ||||
|     ] | ||||
|     search_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
|     sync_single_task = microsoft_entra_sync | ||||
|     sync_task = microsoft_entra_sync | ||||
|     sync_objects_task = microsoft_entra_sync_objects | ||||
|  | ||||
| @ -8,6 +8,7 @@ from django.db import models | ||||
| from django.db.models import QuerySet | ||||
| from django.templatetags.static import static | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import Actor | ||||
| from rest_framework.serializers import Serializer | ||||
|  | ||||
| from authentik.core.models import ( | ||||
| @ -99,6 +100,12 @@ class MicrosoftEntraProvider(OutgoingSyncProvider, BackchannelProvider): | ||||
|         help_text=_("Property mappings used for group creation/updating."), | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def sync_actor(self) -> Actor: | ||||
|         from authentik.enterprise.providers.microsoft_entra.tasks import microsoft_entra_sync | ||||
|  | ||||
|         return microsoft_entra_sync | ||||
|  | ||||
|     def client_for_model( | ||||
|         self, | ||||
|         model: type[User | Group | MicrosoftEntraProviderUser | MicrosoftEntraProviderGroup], | ||||
|  | ||||
| @ -1,13 +0,0 @@ | ||||
| """Microsoft Entra provider task Settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "providers_microsoft_entra_sync": { | ||||
|         "task": "authentik.enterprise.providers.microsoft_entra.tasks.microsoft_entra_sync_all", | ||||
|         "schedule": crontab(minute=fqdn_rand("microsoft_entra_sync_all"), hour="*/4"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
| @ -2,15 +2,13 @@ | ||||
|  | ||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||
| from authentik.enterprise.providers.microsoft_entra.tasks import ( | ||||
|     microsoft_entra_sync, | ||||
|     microsoft_entra_sync_direct, | ||||
|     microsoft_entra_sync_m2m, | ||||
|     microsoft_entra_sync_direct_dispatch, | ||||
|     microsoft_entra_sync_m2m_dispatch, | ||||
| ) | ||||
| from authentik.lib.sync.outgoing.signals import register_signals | ||||
|  | ||||
| register_signals( | ||||
|     MicrosoftEntraProvider, | ||||
|     task_sync_single=microsoft_entra_sync, | ||||
|     task_sync_direct=microsoft_entra_sync_direct, | ||||
|     task_sync_m2m=microsoft_entra_sync_m2m, | ||||
|     task_sync_direct_dispatch=microsoft_entra_sync_direct_dispatch, | ||||
|     task_sync_m2m_dispatch=microsoft_entra_sync_m2m_dispatch, | ||||
| ) | ||||
|  | ||||
| @ -1,37 +1,46 @@ | ||||
| """Microsoft Entra Provider tasks""" | ||||
|  | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import actor | ||||
|  | ||||
| from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider | ||||
| from authentik.events.system_tasks import SystemTask | ||||
| from authentik.lib.sync.outgoing.exceptions import TransientSyncException | ||||
| from authentik.lib.sync.outgoing.tasks import SyncTasks | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| sync_tasks = SyncTasks(MicrosoftEntraProvider) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor(description=_("Sync Microsoft Entra provider objects.")) | ||||
| def microsoft_entra_sync_objects(*args, **kwargs): | ||||
|     return sync_tasks.sync_objects(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     base=SystemTask, bind=True, autoretry_for=(TransientSyncException,), retry_backoff=True | ||||
| ) | ||||
| def microsoft_entra_sync(self, provider_pk: int, *args, **kwargs): | ||||
| @actor(description=_("Full sync for Microsoft Entra provider.")) | ||||
| def microsoft_entra_sync(provider_pk: int, *args, **kwargs): | ||||
|     """Run full sync for Microsoft Entra provider""" | ||||
|     return sync_tasks.sync_single(self, provider_pk, microsoft_entra_sync_objects) | ||||
|     return sync_tasks.sync(provider_pk, microsoft_entra_sync_objects) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def microsoft_entra_sync_all(): | ||||
|     return sync_tasks.sync_all(microsoft_entra_sync) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor(description=_("Sync a direct object (user, group) for Microsoft Entra provider.")) | ||||
| def microsoft_entra_sync_direct(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_direct(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(autoretry_for=(TransientSyncException,), retry_backoff=True) | ||||
| @actor( | ||||
|     description=_("Dispatch syncs for a direct object (user, group) for Microsoft Entra providers.") | ||||
| ) | ||||
| def microsoft_entra_sync_direct_dispatch(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_direct_dispatch(microsoft_entra_sync_direct, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| @actor(description=_("Sync a related object (memberships) for Microsoft Entra provider.")) | ||||
| def microsoft_entra_sync_m2m(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_m2m(*args, **kwargs) | ||||
|  | ||||
|  | ||||
| @actor( | ||||
|     description=_( | ||||
|         "Dispatch syncs for a related object (memberships) for Microsoft Entra providers." | ||||
|     ) | ||||
| ) | ||||
| def microsoft_entra_sync_m2m_dispatch(*args, **kwargs): | ||||
|     return sync_tasks.sync_signal_m2m_dispatch(microsoft_entra_sync_m2m, *args, **kwargs) | ||||
|  | ||||
| @ -252,9 +252,13 @@ class MicrosoftEntraGroupTests(TestCase): | ||||
|             member_add.assert_called_once() | ||||
|             self.assertEqual( | ||||
|                 member_add.call_args[0][0].odata_id, | ||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( | ||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ | ||||
|                     MicrosoftEntraProviderUser.objects.filter( | ||||
|                         provider=self.provider, | ||||
|                     ).first().microsoft_id}", | ||||
|                     ) | ||||
|                     .first() | ||||
|                     .microsoft_id | ||||
|                 }", | ||||
|             ) | ||||
|  | ||||
|     def test_group_create_member_remove(self): | ||||
| @ -311,9 +315,13 @@ class MicrosoftEntraGroupTests(TestCase): | ||||
|             member_add.assert_called_once() | ||||
|             self.assertEqual( | ||||
|                 member_add.call_args[0][0].odata_id, | ||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{MicrosoftEntraProviderUser.objects.filter( | ||||
|                 f"https://graph.microsoft.com/v1.0/directoryObjects/{ | ||||
|                     MicrosoftEntraProviderUser.objects.filter( | ||||
|                         provider=self.provider, | ||||
|                     ).first().microsoft_id}", | ||||
|                     ) | ||||
|                     .first() | ||||
|                     .microsoft_id | ||||
|                 }", | ||||
|             ) | ||||
|             member_remove.assert_called_once() | ||||
|  | ||||
| @ -413,7 +421,7 @@ class MicrosoftEntraGroupTests(TestCase): | ||||
|                 ), | ||||
|             ) as group_list, | ||||
|         ): | ||||
|             microsoft_entra_sync.delay(self.provider.pk).get() | ||||
|             microsoft_entra_sync.send(self.provider.pk).get_result() | ||||
|             self.assertTrue( | ||||
|                 MicrosoftEntraProviderGroup.objects.filter( | ||||
|                     group=different_group, provider=self.provider | ||||
|  | ||||
| @ -397,7 +397,7 @@ class MicrosoftEntraUserTests(APITestCase): | ||||
|                 AsyncMock(return_value=GroupCollectionResponse(value=[])), | ||||
|             ), | ||||
|         ): | ||||
|             microsoft_entra_sync.delay(self.provider.pk).get() | ||||
|             microsoft_entra_sync.send(self.provider.pk).get_result() | ||||
|             self.assertTrue( | ||||
|                 MicrosoftEntraProviderUser.objects.filter( | ||||
|                     user=different_user, provider=self.provider | ||||
|  | ||||
| @ -17,6 +17,7 @@ from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.models import CreatedUpdatedModel | ||||
| from authentik.lib.utils.time import timedelta_from_string, timedelta_string_validator | ||||
| from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider | ||||
| from authentik.tasks.models import TasksModel | ||||
|  | ||||
|  | ||||
| class EventTypes(models.TextChoices): | ||||
| @ -42,7 +43,7 @@ class SSFEventStatus(models.TextChoices): | ||||
|     SENT = "sent" | ||||
|  | ||||
|  | ||||
| class SSFProvider(BackchannelProvider): | ||||
| class SSFProvider(TasksModel, BackchannelProvider): | ||||
|     """Shared Signals Framework provider to allow applications to | ||||
|     receive user events from authentik.""" | ||||
|  | ||||
|  | ||||
| @ -18,7 +18,7 @@ from authentik.enterprise.providers.ssf.models import ( | ||||
|     EventTypes, | ||||
|     SSFProvider, | ||||
| ) | ||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_event | ||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_events | ||||
| from authentik.events.middleware import audit_ignore | ||||
| from authentik.stages.authenticator.models import Device | ||||
| from authentik.stages.authenticator_duo.models import DuoDevice | ||||
| @ -66,7 +66,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | ||||
|  | ||||
|     As this signal is also triggered with a regular logout, we can't be sure | ||||
|     if the session has been deleted by an admin or by the user themselves.""" | ||||
|     send_ssf_event( | ||||
|     send_ssf_events( | ||||
|         EventTypes.CAEP_SESSION_REVOKED, | ||||
|         { | ||||
|             "initiating_entity": "user", | ||||
| @ -88,7 +88,7 @@ def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSessi | ||||
| @receiver(password_changed) | ||||
| def ssf_password_changed_cred_change(sender, user: User, password: str | None, **_): | ||||
|     """Credential change trigger (password changed)""" | ||||
|     send_ssf_event( | ||||
|     send_ssf_events( | ||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||
|         { | ||||
|             "credential_type": "password", | ||||
| @ -126,7 +126,7 @@ def ssf_device_post_save(sender: type[Model], instance: Device, created: bool, * | ||||
|     } | ||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||
|         data["fido2_aaguid"] = instance.aaguid | ||||
|     send_ssf_event( | ||||
|     send_ssf_events( | ||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||
|         data, | ||||
|         sub_id={ | ||||
| @ -153,7 +153,7 @@ def ssf_device_post_delete(sender: type[Model], instance: Device, **_): | ||||
|     } | ||||
|     if isinstance(instance, WebAuthnDevice) and instance.aaguid != UNKNOWN_DEVICE_TYPE_AAGUID: | ||||
|         data["fido2_aaguid"] = instance.aaguid | ||||
|     send_ssf_event( | ||||
|     send_ssf_events( | ||||
|         EventTypes.CAEP_CREDENTIAL_CHANGE, | ||||
|         data, | ||||
|         sub_id={ | ||||
|  | ||||
| @ -1,7 +1,11 @@ | ||||
| from celery import group | ||||
| from typing import Any | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.http import HttpRequest | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import actor | ||||
| from requests.exceptions import RequestException | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| @ -13,19 +17,16 @@ from authentik.enterprise.providers.ssf.models import ( | ||||
|     Stream, | ||||
|     StreamEvent, | ||||
| ) | ||||
| from authentik.events.logs import LogEvent | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask | ||||
| from authentik.lib.utils.http import get_http_session | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| session = get_http_session() | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| def send_ssf_event( | ||||
| def send_ssf_events( | ||||
|     event_type: EventTypes, | ||||
|     data: dict, | ||||
|     stream_filter: dict | None = None, | ||||
| @ -33,7 +34,7 @@ def send_ssf_event( | ||||
|     **extra_data, | ||||
| ): | ||||
|     """Wrapper to send an SSF event to multiple streams""" | ||||
|     payload = [] | ||||
|     events_data = {} | ||||
|     if not stream_filter: | ||||
|         stream_filter = {} | ||||
|     stream_filter["events_requested__contains"] = [event_type] | ||||
| @ -41,16 +42,22 @@ def send_ssf_event( | ||||
|         extra_data.setdefault("txn", request.request_id) | ||||
|     for stream in Stream.objects.filter(**stream_filter): | ||||
|         event_data = stream.prepare_event_payload(event_type, data, **extra_data) | ||||
|         payload.append((str(stream.uuid), event_data)) | ||||
|     return _send_ssf_event.delay(payload) | ||||
|         events_data[stream.uuid] = event_data | ||||
|     ssf_events_dispatch.send(events_data) | ||||
|  | ||||
|  | ||||
| def _check_app_access(stream_uuid: str, event_data: dict) -> bool: | ||||
| @actor(description=_("Dispatch SSF events.")) | ||||
| def ssf_events_dispatch(events_data: dict[str, dict[str, Any]]): | ||||
|     for stream_uuid, event_data in events_data.items(): | ||||
|         stream = Stream.objects.filter(pk=stream_uuid).first() | ||||
|         if not stream: | ||||
|             continue | ||||
|         send_ssf_event.send_with_options(args=(stream_uuid, event_data), rel_obj=stream.provider) | ||||
|  | ||||
|  | ||||
| def _check_app_access(stream: Stream, event_data: dict) -> bool: | ||||
|     """Check if event is related to user and if so, check | ||||
|     if the user has access to the application""" | ||||
|     stream = Stream.objects.filter(pk=stream_uuid).first() | ||||
|     if not stream: | ||||
|         return False | ||||
|     # `event_data` is a dict version of a StreamEvent | ||||
|     sub_id = event_data.get("payload", {}).get("sub_id", {}) | ||||
|     email = sub_id.get("user", {}).get("email", None) | ||||
| @ -65,42 +72,22 @@ def _check_app_access(stream_uuid: str, event_data: dict) -> bool: | ||||
|     return engine.passing | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def _send_ssf_event(event_data: list[tuple[str, dict]]): | ||||
|     tasks = [] | ||||
|     for stream, data in event_data: | ||||
|         if not _check_app_access(stream, data): | ||||
|             continue | ||||
|         event = StreamEvent.objects.create(**data) | ||||
|         tasks.extend(send_single_ssf_event(stream, str(event.uuid))) | ||||
|     main_task = group(*tasks) | ||||
|     main_task() | ||||
| @actor(description=_("Send an SSF event.")) | ||||
| def send_ssf_event(stream_uuid: UUID, event_data: dict[str, Any]): | ||||
|     self: Task = CurrentTask.get_task() | ||||
|  | ||||
|  | ||||
| def send_single_ssf_event(stream_id: str, evt_id: str): | ||||
|     stream = Stream.objects.filter(pk=stream_id).first() | ||||
|     stream = Stream.objects.filter(pk=stream_uuid).first() | ||||
|     if not stream: | ||||
|         return | ||||
|     event = StreamEvent.objects.filter(pk=evt_id).first() | ||||
|     if not event: | ||||
|     if not _check_app_access(stream, event_data): | ||||
|         return | ||||
|     event = StreamEvent.objects.create(**event_data) | ||||
|     self.set_uid(event.pk) | ||||
|     if event.status == SSFEventStatus.SENT: | ||||
|         return | ||||
|     if stream.delivery_method == DeliveryMethods.RISC_PUSH: | ||||
|         return [ssf_push_event.si(str(event.pk))] | ||||
|     return [] | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| def ssf_push_event(self: SystemTask, event_id: str): | ||||
|     self.save_on_success = False | ||||
|     event = StreamEvent.objects.filter(pk=event_id).first() | ||||
|     if not event: | ||||
|         return | ||||
|     self.set_uid(event_id) | ||||
|     if event.status == SSFEventStatus.SENT: | ||||
|         self.set_status(TaskStatus.SUCCESSFUL) | ||||
|     if stream.delivery_method != DeliveryMethods.RISC_PUSH: | ||||
|         return | ||||
|  | ||||
|     try: | ||||
|         response = session.post( | ||||
|             event.stream.endpoint_url, | ||||
| @ -110,26 +97,17 @@ def ssf_push_event(self: SystemTask, event_id: str): | ||||
|         response.raise_for_status() | ||||
|         event.status = SSFEventStatus.SENT | ||||
|         event.save() | ||||
|         self.set_status(TaskStatus.SUCCESSFUL) | ||||
|         return | ||||
|     except RequestException as exc: | ||||
|         LOGGER.warning("Failed to send SSF event", exc=exc) | ||||
|         self.set_status(TaskStatus.ERROR) | ||||
|         attrs = {} | ||||
|         if exc.response: | ||||
|             attrs["response"] = { | ||||
|                 "content": exc.response.text, | ||||
|                 "status": exc.response.status_code, | ||||
|             } | ||||
|         self.set_error( | ||||
|             exc, | ||||
|             LogEvent( | ||||
|                 _("Failed to send request"), | ||||
|                 log_level="warning", | ||||
|                 logger=self.__name__, | ||||
|                 attributes=attrs, | ||||
|             ), | ||||
|         ) | ||||
|         self.warning(exc) | ||||
|         self.warning("Failed to send request", **attrs) | ||||
|         # Re-up the expiry of the stream event | ||||
|         event.expires = now() + timedelta_from_string(event.stream.provider.event_retention) | ||||
|         event.status = SSFEventStatus.PENDING_FAILED | ||||
|  | ||||
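Summarizing the SSF restructure in this file, since it spans several hunks: the old path was a synchronous wrapper enqueueing one `_send_ssf_event` Celery task that built a `group()` of push sub-tasks; the new path is a three-stage pipeline:

```python
# Stage 1  send_ssf_events (plain function): build one payload per
#          matching Stream, keyed by stream UUID.
# Stage 2  ssf_events_dispatch (actor): fan out one send_ssf_event message
#          per stream, linking each Task to its provider via rel_obj (an
#          option assumed to come from django_dramatiq_postgres, not stock
#          dramatiq).
# Stage 3  send_ssf_event (actor): policy/access check, persist the
#          StreamEvent, POST to the endpoint; on RequestException the event
#          is marked PENDING_FAILED and its expiry is extended for retry.
```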
| @ -13,7 +13,7 @@ from authentik.enterprise.providers.ssf.models import ( | ||||
|     SSFProvider, | ||||
|     Stream, | ||||
| ) | ||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_event | ||||
| from authentik.enterprise.providers.ssf.tasks import send_ssf_events | ||||
| from authentik.enterprise.providers.ssf.views.base import SSFView | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| @ -109,7 +109,7 @@ class StreamView(SSFView): | ||||
|                 "User does not have permission to create stream for this provider." | ||||
|             ) | ||||
|         instance: Stream = stream.save(provider=self.provider) | ||||
|         send_ssf_event( | ||||
|         send_ssf_events( | ||||
|             EventTypes.SET_VERIFICATION, | ||||
|             { | ||||
|                 "state": None, | ||||
|  | ||||
| @ -6,7 +6,7 @@ from djangoql.ast import Name | ||||
| from djangoql.exceptions import DjangoQLError | ||||
| from djangoql.queryset import apply_search | ||||
| from djangoql.schema import DjangoQLSchema | ||||
| from rest_framework.filters import SearchFilter | ||||
| from rest_framework.filters import BaseFilterBackend, SearchFilter | ||||
| from rest_framework.request import Request | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| @ -39,19 +39,21 @@ class BaseSchema(DjangoQLSchema): | ||||
|         return super().resolve_name(name) | ||||
|  | ||||
|  | ||||
| class QLSearch(SearchFilter): | ||||
| class QLSearch(BaseFilterBackend): | ||||
|     """rest_framework search filter which uses DjangoQL""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self._fallback = SearchFilter() | ||||
|  | ||||
|     @property | ||||
|     def enabled(self): | ||||
|         return apps.get_app_config("authentik_enterprise").enabled() | ||||
|  | ||||
|     def get_search_terms(self, request) -> str: | ||||
|         """ | ||||
|         Search terms are set by a ?search=... query parameter, | ||||
|         and may be comma and/or whitespace delimited. | ||||
|         """ | ||||
|         params = request.query_params.get(self.search_param, "") | ||||
|     def get_search_terms(self, request: Request) -> str: | ||||
|         """Search terms are set by a ?search=... query parameter, | ||||
|         and may be comma and/or whitespace delimited.""" | ||||
|         params = request.query_params.get("search", "") | ||||
|         params = params.replace("\x00", "")  # strip null characters | ||||
|         return params | ||||
|  | ||||
| @ -70,9 +72,9 @@ class QLSearch(SearchFilter): | ||||
|         search_query = self.get_search_terms(request) | ||||
|         schema = self.get_schema(request, view) | ||||
|         if len(search_query) == 0 or not self.enabled: | ||||
|             return super().filter_queryset(request, queryset, view) | ||||
|             return self._fallback.filter_queryset(request, queryset, view) | ||||
|         try: | ||||
|             return apply_search(queryset, search_query, schema=schema) | ||||
|         except DjangoQLError as exc: | ||||
|             LOGGER.debug("Failed to parse search expression", exc=exc) | ||||
|             return super().filter_queryset(request, queryset, view) | ||||
|             return self._fallback.filter_queryset(request, queryset, view) | ||||
|  | ||||
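The `QLSearch` change swaps inheritance for composition: instead of subclassing `SearchFilter` and calling `super()`, it now implements `BaseFilterBackend` and delegates to a privately held `SearchFilter` when DjangoQL is disabled or the expression fails to parse. The pattern in isolation, as a minimal sketch:

```python
from rest_framework.filters import BaseFilterBackend, SearchFilter


class ComposedSearch(BaseFilterBackend):
    """Delegate to a held SearchFilter instead of inheriting from it."""

    def __init__(self):
        self._fallback = SearchFilter()

    def filter_queryset(self, request, queryset, view):
        query = request.query_params.get("search", "")
        if not query:  # stand-in for "QL disabled or parse failure"
            return self._fallback.filter_queryset(request, queryset, view)
        return queryset  # stand-in for apply_search(queryset, query, ...)
```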
| @ -57,7 +57,7 @@ class QLTest(APITestCase): | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertGreaterEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
|  | ||||
|     def test_search_json(self): | ||||
|  | ||||
| @ -1,17 +1,5 @@ | ||||
| """Enterprise additional settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "enterprise_update_usage": { | ||||
|         "task": "authentik.enterprise.tasks.enterprise_update_usage", | ||||
|         "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     } | ||||
| } | ||||
|  | ||||
| TENANT_APPS = [ | ||||
|     "authentik.enterprise.audit", | ||||
|     "authentik.enterprise.policies.unique_password", | ||||
|  | ||||
| @ -10,6 +10,7 @@ from django.utils.timezone import get_current_timezone | ||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.enterprise.tasks import enterprise_update_usage | ||||
| from authentik.tasks.schedules.models import Schedule | ||||
|  | ||||
|  | ||||
| @receiver(pre_save, sender=License) | ||||
| @ -26,7 +27,7 @@ def pre_save_license(sender: type[License], instance: License, **_): | ||||
| def post_save_license(sender: type[License], instance: License, **_): | ||||
|     """Trigger license usage calculation when license is saved""" | ||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|     enterprise_update_usage.delay() | ||||
|     Schedule.dispatch_by_actor(enterprise_update_usage) | ||||
|  | ||||
|  | ||||
| @receiver(post_delete, sender=License) | ||||
|  | ||||
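`Schedule.dispatch_by_actor` replaces the direct `enterprise_update_usage.delay()` call. A plausible reading (unverified; the model lives in `authentik.tasks.schedules.models`) is that it routes a manual trigger through the stored schedule so ad-hoc and cron-driven runs share one path. A sketch with a dict standing in for the Schedule table:

```python
from dramatiq.actor import Actor

_SCHEDULE_OPTIONS: dict[str, dict] = {}  # actor_name -> stored send options


def dispatch_by_actor(actor: Actor) -> None:
    """Enqueue `actor` using whatever options its schedule row stores."""
    options = _SCHEDULE_OPTIONS.get(actor.actor_name, {})
    actor.send_with_options(**options)
```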
| @ -1,14 +1,11 @@ | ||||
| """Enterprise tasks""" | ||||
|  | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import actor | ||||
|  | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def enterprise_update_usage(self: SystemTask): | ||||
|     """Update enterprise license status""" | ||||
| @actor(description=_("Update enterprise license status.")) | ||||
| def enterprise_update_usage(): | ||||
|     LicenseKey.get_total().record_usage() | ||||
|     self.set_status(TaskStatus.SUCCESSFUL) | ||||
|  | ||||
| @ -1,104 +0,0 @@ | ||||
| """Tasks API""" | ||||
|  | ||||
| from importlib import import_module | ||||
|  | ||||
| from django.contrib import messages | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import ( | ||||
|     CharField, | ||||
|     ChoiceField, | ||||
|     DateTimeField, | ||||
|     FloatField, | ||||
|     SerializerMethodField, | ||||
| ) | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.api.utils import ModelSerializer | ||||
| from authentik.events.logs import LogEventSerializer | ||||
| from authentik.events.models import SystemTask, TaskStatus | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class SystemTaskSerializer(ModelSerializer): | ||||
|     """Serialize TaskInfo and TaskResult""" | ||||
|  | ||||
|     name = CharField() | ||||
|     full_name = SerializerMethodField() | ||||
|     uid = CharField(required=False) | ||||
|     description = CharField() | ||||
|     start_timestamp = DateTimeField(read_only=True) | ||||
|     finish_timestamp = DateTimeField(read_only=True) | ||||
|     duration = FloatField(read_only=True) | ||||
|  | ||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||
|     messages = LogEventSerializer(many=True) | ||||
|  | ||||
|     def get_full_name(self, instance: SystemTask) -> str: | ||||
|         """Get full name with UID""" | ||||
|         if instance.uid: | ||||
|             return f"{instance.name}:{instance.uid}" | ||||
|         return instance.name | ||||
|  | ||||
|     class Meta: | ||||
|         model = SystemTask | ||||
|         fields = [ | ||||
|             "uuid", | ||||
|             "name", | ||||
|             "full_name", | ||||
|             "uid", | ||||
|             "description", | ||||
|             "start_timestamp", | ||||
|             "finish_timestamp", | ||||
|             "duration", | ||||
|             "status", | ||||
|             "messages", | ||||
|             "expires", | ||||
|             "expiring", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class SystemTaskViewSet(ReadOnlyModelViewSet): | ||||
|     """Read-only view set that returns all background tasks""" | ||||
|  | ||||
|     queryset = SystemTask.objects.all() | ||||
|     serializer_class = SystemTaskSerializer | ||||
|     filterset_fields = ["name", "uid", "status"] | ||||
|     ordering = ["name", "uid", "status"] | ||||
|     search_fields = ["name", "description", "uid", "status"] | ||||
|  | ||||
|     @permission_required(None, ["authentik_events.run_task"]) | ||||
|     @extend_schema( | ||||
|         request=OpenApiTypes.NONE, | ||||
|         responses={ | ||||
|             204: OpenApiResponse(description="Task retried successfully"), | ||||
|             404: OpenApiResponse(description="Task not found"), | ||||
|             500: OpenApiResponse(description="Failed to retry task"), | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||
|     def run(self, request: Request, pk=None) -> Response: | ||||
|         """Run task""" | ||||
|         task: SystemTask = self.get_object() | ||||
|         try: | ||||
|             task_module = import_module(task.task_call_module) | ||||
|             task_func = getattr(task_module, task.task_call_func) | ||||
|             LOGGER.info("Running task", task=task_func) | ||||
|             task_func.delay(*task.task_call_args, **task.task_call_kwargs) | ||||
|             messages.success( | ||||
|                 self.request, | ||||
|                 _("Successfully started task {name}.".format_map({"name": task.name})), | ||||
|             ) | ||||
|             return Response(status=204) | ||||
|         except (ImportError, AttributeError) as exc:  # pragma: no cover | ||||
|             LOGGER.warning("Failed to run task, remove state", task=task.name, exc=exc) | ||||
|             # if we get an import error, the module path has probably changed | ||||
|             task.delete() | ||||
|             return Response(status=500) | ||||
| @ -1,12 +1,11 @@ | ||||
| """authentik events app""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
| from prometheus_client import Gauge, Histogram | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.config import CONFIG, ENV_PREFIX | ||||
| from authentik.lib.utils.reflection import path_to_class | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
|  | ||||
| # TODO: Deprecated metric - remove in 2024.2 or later | ||||
| GAUGE_TASKS = Gauge( | ||||
| @ -35,6 +34,17 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Events" | ||||
|     default = True | ||||
|  | ||||
|     @property | ||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.events.tasks import notification_cleanup | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=notification_cleanup, | ||||
|                 crontab=f"{fqdn_rand('notification_cleanup')} */8 * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|     @ManagedAppConfig.reconcile_global | ||||
|     def check_deprecations(self): | ||||
|         """Check for config deprecations""" | ||||
| @ -56,41 +66,3 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|                 replacement_env=replace_env, | ||||
|                 message=msg, | ||||
|             ).save() | ||||
|  | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def prefill_tasks(self): | ||||
|         """Prefill tasks""" | ||||
|         from authentik.events.models import SystemTask | ||||
|         from authentik.events.system_tasks import _prefill_tasks | ||||
|  | ||||
|         for task in _prefill_tasks: | ||||
|             if SystemTask.objects.filter(name=task.name).exists(): | ||||
|                 continue | ||||
|             task.save() | ||||
|             self.logger.debug("prefilled task", task_name=task.name) | ||||
|  | ||||
|     @ManagedAppConfig.reconcile_tenant | ||||
|     def run_scheduled_tasks(self): | ||||
|         """Run schedule tasks which are behind schedule (only applies | ||||
|         to tasks of which we keep metrics)""" | ||||
|         from authentik.events.models import TaskStatus | ||||
|         from authentik.events.system_tasks import SystemTask as CelerySystemTask | ||||
|  | ||||
|         for task in CELERY_APP.conf["beat_schedule"].values(): | ||||
|             schedule = task["schedule"] | ||||
|             if not isinstance(schedule, crontab): | ||||
|                 continue | ||||
|             task_class: CelerySystemTask = path_to_class(task["task"]) | ||||
|             if not isinstance(task_class, CelerySystemTask): | ||||
|                 continue | ||||
|             db_task = task_class.db() | ||||
|             if not db_task: | ||||
|                 continue | ||||
|             due, _ = schedule.is_due(db_task.finish_timestamp) | ||||
|             if due or db_task.status == TaskStatus.UNKNOWN: | ||||
|                 self.logger.debug("Running past-due scheduled task", task=task["task"]) | ||||
|                 task_class.apply_async( | ||||
|                     args=task.get("args", None), | ||||
|                     kwargs=task.get("kwargs", None), | ||||
|                     **task.get("options", {}), | ||||
|                 ) | ||||
|  | ||||
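The hunk above replaces the beat entry from events/settings.py (deleted further down) with a `tenant_schedule_specs` property. The crontab minute comes from `fqdn_rand`, which derives a stable value from the installation's FQDN so that separate installations do not all fire at the same minute. A minimal sketch of that idea, assuming only the standard library (the real helper is `authentik.lib.utils.time.fqdn_rand` and may hash differently):

```python
from hashlib import sha256
from socket import getfqdn


def fqdn_rand_sketch(seed: str, stop: int = 60) -> int:
    """Stable pseudo-random value in [0, stop) derived from the host FQDN.

    Illustrative re-implementation; authentik ships its own version.
    """
    digest = sha256(f"{getfqdn()}-{seed}".encode()).hexdigest()
    return int(digest, 16) % stop


# Spreads the "every 8 hours" cleanup schedule across installations:
print(f"{fqdn_rand_sketch('notification_cleanup')} */8 * * *")
```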
							
								
								
									
authentik/events/migrations/0011_alter_systemtask_options.py (new file, 22 lines)
| @ -0,0 +1,22 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-24 15:36 | ||||
|  | ||||
| from django.db import migrations | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0010_rename_group_notificationrule_destination_group_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterModelOptions( | ||||
|             name="systemtask", | ||||
|             options={ | ||||
|                 "default_permissions": (), | ||||
|                 "permissions": (), | ||||
|                 "verbose_name": "System Task", | ||||
|                 "verbose_name_plural": "System Tasks", | ||||
|             }, | ||||
|         ), | ||||
|     ] | ||||
| @ -5,12 +5,11 @@ from datetime import timedelta | ||||
| from difflib import get_close_matches | ||||
| from functools import lru_cache | ||||
| from inspect import currentframe | ||||
| from smtplib import SMTPException | ||||
| from typing import Any | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.db import connection, models | ||||
| from django.db import models | ||||
| from django.http import HttpRequest | ||||
| from django.http.request import QueryDict | ||||
| from django.utils.timezone import now | ||||
| @ -27,7 +26,6 @@ from authentik.core.middleware import ( | ||||
|     SESSION_KEY_IMPERSONATE_USER, | ||||
| ) | ||||
| from authentik.core.models import ExpiringModel, Group, PropertyMapping, User | ||||
| from authentik.events.apps import GAUGE_TASKS, SYSTEM_TASK_STATUS, SYSTEM_TASK_TIME | ||||
| from authentik.events.context_processors.base import get_context_processors | ||||
| from authentik.events.utils import ( | ||||
|     cleanse_dict, | ||||
| @ -43,6 +41,7 @@ from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.policies.models import PolicyBindingModel | ||||
| from authentik.root.middleware import ClientIPMiddleware | ||||
| from authentik.stages.email.utils import TemplateEmailMessage | ||||
| from authentik.tasks.models import TasksModel | ||||
| from authentik.tenants.models import Tenant | ||||
| from authentik.tenants.utils import get_current_tenant | ||||
|  | ||||
| @ -267,7 +266,8 @@ class Event(SerializerModel, ExpiringModel): | ||||
|             models.Index(fields=["created"]), | ||||
|             models.Index(fields=["client_ip"]), | ||||
|             models.Index( | ||||
|                 models.F("context__authorized_application"), name="authentik_e_ctx_app__idx" | ||||
|                 models.F("context__authorized_application"), | ||||
|                 name="authentik_e_ctx_app__idx", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
| @ -281,7 +281,7 @@ class TransportMode(models.TextChoices): | ||||
|     EMAIL = "email", _("Email") | ||||
|  | ||||
|  | ||||
| class NotificationTransport(SerializerModel): | ||||
| class NotificationTransport(TasksModel, SerializerModel): | ||||
|     """Action which is executed when a Rule matches""" | ||||
|  | ||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||
| @ -446,6 +446,8 @@ class NotificationTransport(SerializerModel): | ||||
|  | ||||
|     def send_email(self, notification: "Notification") -> list[str]: | ||||
|         """Send notification via global email configuration""" | ||||
|         from authentik.stages.email.tasks import send_mail | ||||
|  | ||||
|         if notification.user.email.strip() == "": | ||||
|             LOGGER.info( | ||||
|                 "Discarding notification as user has no email address", | ||||
| @ -487,17 +489,14 @@ class NotificationTransport(SerializerModel): | ||||
|             template_name="email/event_notification.html", | ||||
|             template_context=context, | ||||
|         ) | ||||
|         # Email is sent directly here, as the call to send() should have been from a task. | ||||
|         try: | ||||
|             from authentik.stages.email.tasks import send_mail | ||||
|  | ||||
|             return send_mail(mail.__dict__) | ||||
|         except (SMTPException, ConnectionError, OSError) as exc: | ||||
|             raise NotificationTransportError(exc) from exc | ||||
|         send_mail.send_with_options(args=(mail.__dict__,), rel_obj=self) | ||||
|         return [] | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.events.api.notification_transports import NotificationTransportSerializer | ||||
|         from authentik.events.api.notification_transports import ( | ||||
|             NotificationTransportSerializer, | ||||
|         ) | ||||
|  | ||||
|         return NotificationTransportSerializer | ||||
|  | ||||
| @ -547,7 +546,7 @@ class Notification(SerializerModel): | ||||
|         verbose_name_plural = _("Notifications") | ||||
|  | ||||
|  | ||||
| class NotificationRule(SerializerModel, PolicyBindingModel): | ||||
| class NotificationRule(TasksModel, SerializerModel, PolicyBindingModel): | ||||
|     """Decide when to create a Notification based on policies attached to this object.""" | ||||
|  | ||||
|     name = models.TextField(unique=True) | ||||
| @ -611,7 +610,9 @@ class NotificationWebhookMapping(PropertyMapping): | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[type[Serializer]]: | ||||
|         from authentik.events.api.notification_mappings import NotificationWebhookMappingSerializer | ||||
|         from authentik.events.api.notification_mappings import ( | ||||
|             NotificationWebhookMappingSerializer, | ||||
|         ) | ||||
|  | ||||
|         return NotificationWebhookMappingSerializer | ||||
|  | ||||
| @ -624,7 +625,7 @@ class NotificationWebhookMapping(PropertyMapping): | ||||
|  | ||||
|  | ||||
| class TaskStatus(models.TextChoices): | ||||
|     """Possible states of tasks""" | ||||
|     """DEPRECATED do not use""" | ||||
|  | ||||
|     UNKNOWN = "unknown" | ||||
|     SUCCESSFUL = "successful" | ||||
| @ -632,8 +633,8 @@ class TaskStatus(models.TextChoices): | ||||
|     ERROR = "error" | ||||
|  | ||||
|  | ||||
| class SystemTask(SerializerModel, ExpiringModel): | ||||
|     """Info about a system task running in the background along with details to restart the task""" | ||||
| class SystemTask(ExpiringModel): | ||||
|     """DEPRECATED do not use""" | ||||
|  | ||||
|     uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||
|     name = models.TextField() | ||||
| @ -653,41 +654,13 @@ class SystemTask(SerializerModel, ExpiringModel): | ||||
|     task_call_args = models.JSONField(default=list) | ||||
|     task_call_kwargs = models.JSONField(default=dict) | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.events.api.tasks import SystemTaskSerializer | ||||
|  | ||||
|         return SystemTaskSerializer | ||||
|  | ||||
|     def update_metrics(self): | ||||
|         """Update prometheus metrics""" | ||||
|         # TODO: Deprecated metric - remove in 2024.2 or later | ||||
|         GAUGE_TASKS.labels( | ||||
|             tenant=connection.schema_name, | ||||
|             task_name=self.name, | ||||
|             task_uid=self.uid or "", | ||||
|             status=self.status.lower(), | ||||
|         ).set(self.duration) | ||||
|         SYSTEM_TASK_TIME.labels( | ||||
|             tenant=connection.schema_name, | ||||
|             task_name=self.name, | ||||
|             task_uid=self.uid or "", | ||||
|         ).observe(self.duration) | ||||
|         SYSTEM_TASK_STATUS.labels( | ||||
|             tenant=connection.schema_name, | ||||
|             task_name=self.name, | ||||
|             task_uid=self.uid or "", | ||||
|             status=self.status.lower(), | ||||
|         ).inc() | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"System Task {self.name}" | ||||
|  | ||||
|     class Meta: | ||||
|         unique_together = (("name", "uid"),) | ||||
|         # Remove "add", "change" and "delete" permissions as those are not used | ||||
|         default_permissions = ["view"] | ||||
|         permissions = [("run_task", _("Run task"))] | ||||
|         default_permissions = () | ||||
|         permissions = () | ||||
|         verbose_name = _("System Task") | ||||
|         verbose_name_plural = _("System Tasks") | ||||
|         indexes = ExpiringModel.Meta.indexes | ||||
|  | ||||
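Across this diff, celery's `.delay()`/`.apply_async()` calls become dramatiq-style `send()`/`send_with_options()`, and the `rel_obj` option (from authentik's django_dramatiq_postgres integration, not stock dramatiq) links the queued task back to the model row that triggered it. A self-contained sketch of the stock half, using dramatiq's stub broker:

```python
import dramatiq
from dramatiq.brokers.stub import StubBroker

dramatiq.set_broker(StubBroker())


@dramatiq.actor
def send_mail(message: dict):
    # Stand-in for authentik's email actor; the real code renders and
    # sends a TemplateEmailMessage.
    print("sending", message.get("subject"))


# Fire-and-forget enqueue; extra options ride along in the message.
send_mail.send_with_options(
    args=({"subject": "authentik Notification"},),
    # rel_obj=transport  # authentik-specific: links the task to a model row
)
```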
| @ -1,13 +0,0 @@ | ||||
| """Event Settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "events_notification_cleanup": { | ||||
|         "task": "authentik.events.tasks.notification_cleanup", | ||||
|         "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
| @ -12,13 +12,10 @@ from rest_framework.request import Request | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.core.signals import login_failed, password_changed | ||||
| from authentik.events.apps import SYSTEM_TASK_STATUS | ||||
| from authentik.events.models import Event, EventAction, SystemTask | ||||
| from authentik.events.tasks import event_notification_handler, gdpr_cleanup | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.flows.models import Stage | ||||
| from authentik.flows.planner import PLAN_CONTEXT_OUTPOST, PLAN_CONTEXT_SOURCE, FlowPlan | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.root.monitoring import monitoring_set | ||||
| from authentik.stages.invitation.models import Invitation | ||||
| from authentik.stages.invitation.signals import invitation_used | ||||
| from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS | ||||
| @ -114,19 +111,15 @@ def on_password_changed(sender, user: User, password: str, request: HttpRequest | ||||
| @receiver(post_save, sender=Event) | ||||
| def event_post_save_notification(sender, instance: Event, **_): | ||||
|     """Start task to check if any policies trigger a notification on this event""" | ||||
|     event_notification_handler.delay(instance.event_uuid.hex) | ||||
|     from authentik.events.tasks import event_trigger_dispatch | ||||
|  | ||||
|     event_trigger_dispatch.send(instance.event_uuid) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=User) | ||||
| def event_user_pre_delete_cleanup(sender, instance: User, **_): | ||||
|     """If gdpr_compliance is enabled, remove all the user's events""" | ||||
|     from authentik.events.tasks import gdpr_cleanup | ||||
|  | ||||
|     if get_current_tenant().gdpr_compliance: | ||||
|         gdpr_cleanup.delay(instance.pk) | ||||
|  | ||||
|  | ||||
| @receiver(monitoring_set) | ||||
| def monitoring_system_task(sender, **_): | ||||
|     """Update metrics when task is saved""" | ||||
|     SYSTEM_TASK_STATUS.clear() | ||||
|     for task in SystemTask.objects.all(): | ||||
|         task.update_metrics() | ||||
|         gdpr_cleanup.send(instance.pk) | ||||
|  | ||||
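Note that the handlers now import their actors inside the function body rather than at module scope; `authentik.events.tasks` itself imports from `authentik.events.models`, so a top-level import back into `signals` would risk a cycle at app startup. Restated in isolation, the pattern is:

```python
from django.db.models.signals import post_save
from django.dispatch import receiver

from authentik.events.models import Event


@receiver(post_save, sender=Event)
def event_post_save_notification(sender, instance: Event, **_):
    # Deferred import: the tasks module imports models itself, so
    # importing it lazily avoids a circular import at app-loading time.
    from authentik.events.tasks import event_trigger_dispatch

    event_trigger_dispatch.send(instance.event_uuid)
```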
| @ -1,156 +0,0 @@ | ||||
| """Monitored tasks""" | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
| from time import perf_counter | ||||
| from typing import Any | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
| from tenant_schemas_celery.task import TenantTask | ||||
|  | ||||
| from authentik.events.logs import LogEvent | ||||
| from authentik.events.models import Event, EventAction, TaskStatus | ||||
| from authentik.events.models import SystemTask as DBSystemTask | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
|  | ||||
|  | ||||
| class SystemTask(TenantTask): | ||||
|     """Task which can save its state to the cache""" | ||||
|  | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     # For tasks that should only be listed if they failed, set this to False | ||||
|     save_on_success: bool | ||||
|  | ||||
|     _status: TaskStatus | ||||
|     _messages: list[LogEvent] | ||||
|  | ||||
|     _uid: str | None | ||||
|     # Precise start time from perf_counter | ||||
|     _start_precise: float | None = None | ||||
|     _start: datetime | None = None | ||||
|  | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self._status = TaskStatus.SUCCESSFUL | ||||
|         self.save_on_success = True | ||||
|         self._uid = None | ||||
|         self._status = None | ||||
|         self._messages = [] | ||||
|         self.result_timeout_hours = 6 | ||||
|  | ||||
|     def set_uid(self, uid: str): | ||||
|         """Set UID, so in the case of an unexpected error its saved correctly""" | ||||
|         self._uid = uid | ||||
|  | ||||
|     def set_status(self, status: TaskStatus, *messages: LogEvent): | ||||
|         """Set result for current run, will overwrite previous result.""" | ||||
|         self._status = status | ||||
|         self._messages = list(messages) | ||||
|         for idx, msg in enumerate(self._messages): | ||||
|             if not isinstance(msg, LogEvent): | ||||
|                 self._messages[idx] = LogEvent(msg, logger=self.__name__, log_level="info") | ||||
|  | ||||
|     def set_error(self, exception: Exception, *messages: LogEvent): | ||||
|         """Set result to error and save exception""" | ||||
|         self._status = TaskStatus.ERROR | ||||
|         self._messages = list(messages) | ||||
|         self._messages.extend( | ||||
|             [LogEvent(exception_to_string(exception), logger=self.__name__, log_level="error")] | ||||
|         ) | ||||
|  | ||||
|     def before_start(self, task_id, args, kwargs): | ||||
|         self._start_precise = perf_counter() | ||||
|         self._start = now() | ||||
|         self.logger = get_logger().bind(task_id=task_id) | ||||
|         return super().before_start(task_id, args, kwargs) | ||||
|  | ||||
|     def db(self) -> DBSystemTask | None: | ||||
|         """Get DB object for latest task""" | ||||
|         return DBSystemTask.objects.filter( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
|         ).first() | ||||
|  | ||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._status: | ||||
|             return | ||||
|         if self._status == TaskStatus.SUCCESSFUL and not self.save_on_success: | ||||
|             DBSystemTask.objects.filter( | ||||
|                 name=self.__name__, | ||||
|                 uid=self._uid, | ||||
|             ).delete() | ||||
|             return | ||||
|         DBSystemTask.objects.update_or_create( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
|             defaults={ | ||||
|                 "description": self.__doc__, | ||||
|                 "start_timestamp": self._start or now(), | ||||
|                 "finish_timestamp": now(), | ||||
|                 "duration": max(perf_counter() - self._start_precise, 0), | ||||
|                 "task_call_module": self.__module__, | ||||
|                 "task_call_func": self.__name__, | ||||
|                 "task_call_args": sanitize_item(args), | ||||
|                 "task_call_kwargs": sanitize_item(kwargs), | ||||
|                 "status": self._status, | ||||
|                 "messages": sanitize_item(self._messages), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||
|                 "expiring": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def on_failure(self, exc, task_id, args, kwargs, einfo): | ||||
|         super().on_failure(exc, task_id, args, kwargs, einfo=einfo) | ||||
|         if not self._status: | ||||
|             self.set_error(exc) | ||||
|         DBSystemTask.objects.update_or_create( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
|             defaults={ | ||||
|                 "description": self.__doc__, | ||||
|                 "start_timestamp": self._start or now(), | ||||
|                 "finish_timestamp": now(), | ||||
|                 "duration": max(perf_counter() - self._start_precise, 0), | ||||
|                 "task_call_module": self.__module__, | ||||
|                 "task_call_func": self.__name__, | ||||
|                 "task_call_args": sanitize_item(args), | ||||
|                 "task_call_kwargs": sanitize_item(kwargs), | ||||
|                 "status": self._status, | ||||
|                 "messages": sanitize_item(self._messages), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), | ||||
|                 "expiring": True, | ||||
|             }, | ||||
|         ) | ||||
|         Event.new( | ||||
|             EventAction.SYSTEM_TASK_EXCEPTION, | ||||
|             message=f"Task {self.__name__} encountered an error: {exception_to_string(exc)}", | ||||
|         ).save() | ||||
|  | ||||
|     def run(self, *args, **kwargs): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| def prefill_task(func): | ||||
|     """Ensure a task's details are always in cache, so it can always be triggered via API""" | ||||
|     _prefill_tasks.append( | ||||
|         DBSystemTask( | ||||
|             name=func.__name__, | ||||
|             description=func.__doc__, | ||||
|             start_timestamp=now(), | ||||
|             finish_timestamp=now(), | ||||
|             status=TaskStatus.UNKNOWN, | ||||
|             messages=sanitize_item([_("Task has not been run yet.")]), | ||||
|             task_call_module=func.__module__, | ||||
|             task_call_func=func.__name__, | ||||
|             expiring=False, | ||||
|             duration=0, | ||||
|         ) | ||||
|     ) | ||||
|     return func | ||||
|  | ||||
|  | ||||
| _prefill_tasks = [] | ||||
| @ -1,41 +1,49 @@ | ||||
| """Event notification tasks""" | ||||
|  | ||||
| from uuid import UUID | ||||
|  | ||||
| from django.db.models.query_utils import Q | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import actor | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.expression.exceptions import PropertyMappingExpressionException | ||||
| from authentik.core.models import User | ||||
| from authentik.events.models import ( | ||||
|     Event, | ||||
|     Notification, | ||||
|     NotificationRule, | ||||
|     NotificationTransport, | ||||
|     NotificationTransportError, | ||||
|     TaskStatus, | ||||
| ) | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| from authentik.policies.models import PolicyBinding, PolicyEngineMode | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def event_notification_handler(event_uuid: str): | ||||
|     """Start task for each trigger definition""" | ||||
| @actor(description=_("Dispatch new event notifications.")) | ||||
| def event_trigger_dispatch(event_uuid: UUID): | ||||
|     for trigger in NotificationRule.objects.all(): | ||||
|         event_trigger_handler.apply_async(args=[event_uuid, trigger.name], queue="authentik_events") | ||||
|         event_trigger_handler.send_with_options(args=(event_uuid, trigger.name), rel_obj=trigger) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||
| @actor( | ||||
|     description=_( | ||||
|         "Check if policies attached to NotificationRule match event " | ||||
|         "and dispatch notification tasks." | ||||
|     ) | ||||
| ) | ||||
| def event_trigger_handler(event_uuid: UUID, trigger_name: str): | ||||
|     """Check if policies attached to NotificationRule match event""" | ||||
|     self: Task = CurrentTask.get_task() | ||||
|  | ||||
|     event: Event = Event.objects.filter(event_uuid=event_uuid).first() | ||||
|     if not event: | ||||
|         LOGGER.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||
|         self.warning("event doesn't exist yet or anymore", event_uuid=event_uuid) | ||||
|         return | ||||
|  | ||||
|     trigger: NotificationRule | None = NotificationRule.objects.filter(name=trigger_name).first() | ||||
|     if not trigger: | ||||
|         return | ||||
| @ -70,57 +78,46 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||
|  | ||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||
|     # Create the notification objects | ||||
|     count = 0 | ||||
|     for transport in trigger.transports.all(): | ||||
|         for user in trigger.destination_users(event): | ||||
|             LOGGER.debug("created notification") | ||||
|             notification_transport.apply_async( | ||||
|                 args=[ | ||||
|             notification_transport.send_with_options( | ||||
|                 args=( | ||||
|                     transport.pk, | ||||
|                     str(event.pk), | ||||
|                     event.pk, | ||||
|                     user.pk, | ||||
|                     str(trigger.pk), | ||||
|                 ], | ||||
|                 queue="authentik_events", | ||||
|                     trigger.pk, | ||||
|                 ), | ||||
|                 rel_obj=transport, | ||||
|             ) | ||||
|             count += 1 | ||||
|             if transport.send_once: | ||||
|                 break | ||||
|     self.info(f"Created {count} notification tasks") | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     bind=True, | ||||
|     autoretry_for=(NotificationTransportError,), | ||||
|     retry_backoff=True, | ||||
|     base=SystemTask, | ||||
| ) | ||||
| def notification_transport( | ||||
|     self: SystemTask, transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str | ||||
| ): | ||||
| @actor(description=_("Send notification.")) | ||||
| def notification_transport(transport_pk: int, event_pk: str, user_pk: int, trigger_pk: str): | ||||
|     """Send notification over specified transport""" | ||||
|     self.save_on_success = False | ||||
|     try: | ||||
|         event = Event.objects.filter(pk=event_pk).first() | ||||
|         if not event: | ||||
|             return | ||||
|         user = User.objects.filter(pk=user_pk).first() | ||||
|         if not user: | ||||
|             return | ||||
|         trigger = NotificationRule.objects.filter(pk=trigger_pk).first() | ||||
|         if not trigger: | ||||
|             return | ||||
|         notification = Notification( | ||||
|             severity=trigger.severity, body=event.summary, event=event, user=user | ||||
|         ) | ||||
|         transport = NotificationTransport.objects.filter(pk=transport_pk).first() | ||||
|         if not transport: | ||||
|             return | ||||
|         transport.send(notification) | ||||
|         self.set_status(TaskStatus.SUCCESSFUL) | ||||
|     except (NotificationTransportError, PropertyMappingExpressionException) as exc: | ||||
|         self.set_error(exc) | ||||
|         raise exc | ||||
|     event = Event.objects.filter(pk=event_pk).first() | ||||
|     if not event: | ||||
|         return | ||||
|     user = User.objects.filter(pk=user_pk).first() | ||||
|     if not user: | ||||
|         return | ||||
|     trigger = NotificationRule.objects.filter(pk=trigger_pk).first() | ||||
|     if not trigger: | ||||
|         return | ||||
|     notification = Notification( | ||||
|         severity=trigger.severity, body=event.summary, event=event, user=user | ||||
|     ) | ||||
|     transport: NotificationTransport = NotificationTransport.objects.filter(pk=transport_pk).first() | ||||
|     if not transport: | ||||
|         return | ||||
|     transport.send(notification) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| @actor(description=_("Cleanup events for GDPR compliance.")) | ||||
| def gdpr_cleanup(user_pk: int): | ||||
|     """cleanup events from gdpr_compliance""" | ||||
|     events = Event.objects.filter(user__pk=user_pk) | ||||
| @ -128,12 +125,12 @@ def gdpr_cleanup(user_pk: int): | ||||
|     events.delete() | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def notification_cleanup(self: SystemTask): | ||||
| @actor(description=_("Cleanup seen notifications and notifications whose event expired.")) | ||||
| def notification_cleanup(): | ||||
|     """Cleanup seen notifications and notifications whose event expired.""" | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     notifications = Notification.objects.filter(Q(event=None) | Q(seen=True)) | ||||
|     amount = notifications.count() | ||||
|     notifications.delete() | ||||
|     LOGGER.debug("Expired notifications", amount=amount) | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, f"Expired {amount} Notifications") | ||||
|     self.info(f"Expired {amount} Notifications") | ||||
|  | ||||
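The replacement for the deleted celery `SystemTask` base class is visible above: a task is now a plain function under `@actor`, and it reaches its own DB-backed `Task` row through `CurrentTask` to record log messages (`self.info(...)`, `self.warning(...)`). Condensed to its skeleton, with a hypothetical task body:

```python
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor

from authentik.tasks.models import Task


@actor(description=_("Example cleanup task."))
def example_cleanup():
    # CurrentTask exposes the DB-backed Task row for this execution,
    # replacing celery's bound `self: SystemTask`.
    self: Task = CurrentTask.get_task()
    deleted = 0  # ... do the actual work here ...
    self.info(f"Deleted {deleted} objects")
```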
| @ -1,103 +0,0 @@ | ||||
| """Test Monitored tasks""" | ||||
|  | ||||
| from json import loads | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.tasks import clean_expired_models | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.models import SystemTask as DBSystemTask | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
|  | ||||
| class TestSystemTasks(APITestCase): | ||||
|     """Test Monitored tasks""" | ||||
|  | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
|         self.user = create_test_admin_user() | ||||
|         self.client.force_login(self.user) | ||||
|  | ||||
|     def test_failed_successful_remove_state(self): | ||||
|         """Test that a task with `save_on_success` set to `False` that failed saves | ||||
|         a state, and upon successful completion will delete the state""" | ||||
|         should_fail = True | ||||
|         uid = generate_id() | ||||
|  | ||||
|         @CELERY_APP.task( | ||||
|             bind=True, | ||||
|             base=SystemTask, | ||||
|         ) | ||||
|         def test_task(self: SystemTask): | ||||
|             self.save_on_success = False | ||||
|             self.set_uid(uid) | ||||
|             self.set_status(TaskStatus.ERROR if should_fail else TaskStatus.SUCCESSFUL) | ||||
|  | ||||
|         # First test successful run | ||||
|         should_fail = False | ||||
|         test_task.delay().get() | ||||
|         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) | ||||
|  | ||||
|         # Then test failed | ||||
|         should_fail = True | ||||
|         test_task.delay().get() | ||||
|         task = DBSystemTask.objects.filter(name="test_task", uid=uid).first() | ||||
|         self.assertEqual(task.status, TaskStatus.ERROR) | ||||
|  | ||||
|         # Then after that, the state should be removed | ||||
|         should_fail = False | ||||
|         test_task.delay().get() | ||||
|         self.assertIsNone(DBSystemTask.objects.filter(name="test_task", uid=uid).first()) | ||||
|  | ||||
|     def test_tasks(self): | ||||
|         """Test Task API""" | ||||
|         clean_expired_models.delay().get() | ||||
|         response = self.client.get(reverse("authentik_api:systemtask-list")) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content) | ||||
|         self.assertTrue(any(task["name"] == "clean_expired_models" for task in body["results"])) | ||||
|  | ||||
|     def test_tasks_single(self): | ||||
|         """Test Task API (read single)""" | ||||
|         clean_expired_models.delay().get() | ||||
|         task = DBSystemTask.objects.filter(name="clean_expired_models").first() | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:systemtask-detail", | ||||
|                 kwargs={"pk": str(task.pk)}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         body = loads(response.content) | ||||
|         self.assertEqual(body["status"], TaskStatus.SUCCESSFUL.value) | ||||
|         self.assertEqual(body["name"], "clean_expired_models") | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:systemtask-detail", kwargs={"pk": "qwerqwer"}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
|  | ||||
|     def test_tasks_run(self): | ||||
|         """Test Task API (run)""" | ||||
|         clean_expired_models.delay().get() | ||||
|         task = DBSystemTask.objects.filter(name="clean_expired_models").first() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:systemtask-run", | ||||
|                 kwargs={"pk": str(task.pk)}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|  | ||||
|     def test_tasks_run_404(self): | ||||
|         """Test Task API (run, 404)""" | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_api:systemtask-run", | ||||
|                 kwargs={"pk": "qwerqewrqrqewrqewr"}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 404) | ||||
| @ -5,13 +5,11 @@ from authentik.events.api.notification_mappings import NotificationWebhookMappin | ||||
| from authentik.events.api.notification_rules import NotificationRuleViewSet | ||||
| from authentik.events.api.notification_transports import NotificationTransportViewSet | ||||
| from authentik.events.api.notifications import NotificationViewSet | ||||
| from authentik.events.api.tasks import SystemTaskViewSet | ||||
|  | ||||
| api_urlpatterns = [ | ||||
|     ("events/events", EventViewSet), | ||||
|     ("events/notifications", NotificationViewSet), | ||||
|     ("events/transports", NotificationTransportViewSet), | ||||
|     ("events/rules", NotificationRuleViewSet), | ||||
|     ("events/system_tasks", SystemTaskViewSet), | ||||
|     ("propertymappings/notification", NotificationWebhookMappingViewSet), | ||||
| ] | ||||
|  | ||||
| @ -41,6 +41,7 @@ REDIS_ENV_KEYS = [ | ||||
| # Old key -> new key | ||||
| DEPRECATIONS = { | ||||
|     "geoip": "events.context_processors.geoip", | ||||
|     "worker.concurrency": "worker.processes", | ||||
|     "redis.broker_url": "broker.url", | ||||
|     "redis.broker_transport_options": "broker.transport_options", | ||||
|     "redis.cache_timeout": "cache.timeout", | ||||
|  | ||||
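`worker.concurrency` joins the deprecation map, pointing at the new `worker.processes` key shown in the default config below. A sketch of how such a remap can be applied when loading config, assuming a flat dict of dotted keys (the real logic lives in `authentik.lib.config` and `check_deprecations`):

```python
DEPRECATIONS = {"worker.concurrency": "worker.processes"}


def apply_deprecations(config: dict) -> dict:
    """Copy values from deprecated keys to their replacements (illustrative)."""
    for old, new in DEPRECATIONS.items():
        if old in config and new not in config:
            print(f"config key {old} is deprecated, use {new}")
            config[new] = config.pop(old)
    return config


print(apply_deprecations({"worker.concurrency": 2}))
# -> {'worker.processes': 2}
```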
| @ -21,6 +21,10 @@ def start_debug_server(**kwargs) -> bool: | ||||
|  | ||||
|     listen: str = CONFIG.get("listen.listen_debug_py", "127.0.0.1:9901") | ||||
|     host, _, port = listen.rpartition(":") | ||||
|     debugpy.listen((host, int(port)), **kwargs)  # nosec | ||||
|     try: | ||||
|         debugpy.listen((host, int(port)), **kwargs)  # nosec | ||||
|     except RuntimeError: | ||||
|         LOGGER.warning("Could not start debug server. Continuing without") | ||||
|         return False | ||||
|     LOGGER.debug("Starting debug server", host=host, port=port) | ||||
|     return True | ||||
|  | ||||
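The new guard matters under the multi-process worker model introduced below: `debugpy.listen()` raises `RuntimeError` when it cannot start its adapter, for example when a listener is already active in the process, so catching it lets the process continue without debugging instead of crashing.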
| @ -8,9 +8,9 @@ | ||||
| # make gen-dev-config | ||||
| # ``` | ||||
| # | ||||
| # You may edit the generated file to override the configuration below.   | ||||
| # You may edit the generated file to override the configuration below. | ||||
| # | ||||
| # When making modifying the default configuration file,  | ||||
| # When modifying the default configuration file, | ||||
| # ensure that the corresponding documentation is updated to match. | ||||
| # | ||||
| # @see {@link ../../website/docs/install-config/configuration/configuration.mdx Configuration documentation} for more information. | ||||
| @ -157,7 +157,14 @@ web: | ||||
|   path: / | ||||
|  | ||||
| worker: | ||||
|   concurrency: 2 | ||||
|   processes: 2 | ||||
|   threads: 1 | ||||
|   consumer_listen_timeout: "seconds=30" | ||||
|   task_max_retries: 20 | ||||
|   task_default_time_limit: "minutes=10" | ||||
|   task_purge_interval: "days=1" | ||||
|   task_expiration: "days=30" | ||||
|   scheduler_interval: "seconds=60" | ||||
|  | ||||
| storage: | ||||
|   media: | ||||
|  | ||||
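The new duration values (`seconds=30`, `minutes=10`, `days=1`, `days=30`) follow the `timedelta_from_string` convention imported in models.py earlier in this diff: each segment is a `datetime.timedelta` keyword argument. An illustrative parser under that assumption (authentik's own version lives in `authentik.lib.utils.time`):

```python
from datetime import timedelta


def timedelta_from_string(expr: str) -> timedelta:
    """Parse 'hours=1;minutes=30' style strings into a timedelta.

    Illustrative re-implementation of the assumed format.
    """
    kwargs = {}
    for part in expr.split(";"):
        key, _, value = part.partition("=")
        kwargs[key.strip()] = float(value)
    return timedelta(**kwargs)


print(timedelta_from_string("seconds=30"))   # 0:00:30
print(timedelta_from_string("minutes=10"))   # 0:10:00
```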
| @ -88,7 +88,6 @@ def get_logger_config(): | ||||
|         "authentik": global_level, | ||||
|         "django": "WARNING", | ||||
|         "django.request": "ERROR", | ||||
|         "celery": "WARNING", | ||||
|         "selenium": "WARNING", | ||||
|         "docker": "WARNING", | ||||
|         "urllib3": "WARNING", | ||||
|  | ||||
| @ -3,8 +3,6 @@ | ||||
| from asyncio.exceptions import CancelledError | ||||
| from typing import Any | ||||
|  | ||||
| from billiard.exceptions import SoftTimeLimitExceeded, WorkerLostError | ||||
| from celery.exceptions import CeleryError | ||||
| from channels_redis.core import ChannelFull | ||||
| from django.conf import settings | ||||
| from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation, ValidationError | ||||
| @ -22,7 +20,6 @@ from sentry_sdk import HttpTransport, get_current_scope | ||||
| from sentry_sdk import init as sentry_sdk_init | ||||
| from sentry_sdk.api import set_tag | ||||
| from sentry_sdk.integrations.argv import ArgvIntegration | ||||
| from sentry_sdk.integrations.celery import CeleryIntegration | ||||
| from sentry_sdk.integrations.django import DjangoIntegration | ||||
| from sentry_sdk.integrations.redis import RedisIntegration | ||||
| from sentry_sdk.integrations.socket import SocketIntegration | ||||
| @ -71,10 +68,6 @@ ignored_classes = ( | ||||
|     LocalProtocolError, | ||||
|     # rest_framework error | ||||
|     APIException, | ||||
|     # celery errors | ||||
|     WorkerLostError, | ||||
|     CeleryError, | ||||
|     SoftTimeLimitExceeded, | ||||
|     # custom baseclass | ||||
|     SentryIgnoredException, | ||||
|     # ldap errors | ||||
| @ -115,7 +108,6 @@ def sentry_init(**sentry_init_kwargs): | ||||
|             ArgvIntegration(), | ||||
|             StdlibIntegration(), | ||||
|             DjangoIntegration(transaction_style="function_name", cache_spans=True), | ||||
|             CeleryIntegration(), | ||||
|             RedisIntegration(), | ||||
|             ThreadingIntegration(propagate_hub=True), | ||||
|             SocketIntegration(), | ||||
| @ -160,14 +152,11 @@ def before_send(event: dict, hint: dict) -> dict | None: | ||||
|             return None | ||||
|     if "logger" in event: | ||||
|         if event["logger"] in [ | ||||
|             "kombu", | ||||
|             "asyncio", | ||||
|             "multiprocessing", | ||||
|             "django_redis", | ||||
|             "django.security.DisallowedHost", | ||||
|             "django_redis.cache", | ||||
|             "celery.backends.redis", | ||||
|             "celery.worker", | ||||
|             "paramiko.transport", | ||||
|         ]: | ||||
|             return None | ||||
|  | ||||
							
								
								
									
authentik/lib/sync/api.py (new file, 12 lines)
| @ -0,0 +1,12 @@ | ||||
| from rest_framework.fields import BooleanField, ChoiceField, DateTimeField | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.tasks.models import TaskStatus | ||||
|  | ||||
|  | ||||
| class SyncStatusSerializer(PassiveSerializer): | ||||
|     """Provider/source sync status""" | ||||
|  | ||||
|     is_running = BooleanField() | ||||
|     last_successful_sync = DateTimeField(required=False) | ||||
|     last_sync_status = ChoiceField(required=False, choices=TaskStatus.choices) | ||||
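Since `PassiveSerializer` has no backing model, the status endpoint below feeds this serializer a plain dict; optional fields are rendered only when present in the input. A hypothetical payload, assuming the `TaskStatus` choices from `authentik.tasks.models`:

```python
from django.utils.timezone import now

from authentik.lib.sync.api import SyncStatusSerializer
from authentik.tasks.models import TaskStatus

status = {
    "is_running": False,
    "last_successful_sync": now(),
    "last_sync_status": TaskStatus.DONE,
}
print(SyncStatusSerializer(status).data)
```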
| @ -1,7 +1,7 @@ | ||||
| """Sync constants""" | ||||
|  | ||||
| PAGE_SIZE = 100 | ||||
| PAGE_TIMEOUT = 60 * 60 * 0.5  # Half an hour | ||||
| PAGE_TIMEOUT_MS = 60 * 60 * 0.5 * 1000  # Half an hour | ||||
| HTTP_CONFLICT = 409 | ||||
| HTTP_NO_CONTENT = 204 | ||||
| HTTP_SERVICE_UNAVAILABLE = 503 | ||||
|  | ||||
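For reference: 60 * 60 * 0.5 * 1000 = 1,800,000 ms, the same half hour as before; the `_MS` rename makes the unit explicit, since dramatiq expresses `time_limit` options in milliseconds.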
| @ -1,7 +1,5 @@ | ||||
| from celery import Task | ||||
| from django.utils.text import slugify | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from dramatiq.actor import Actor | ||||
| from drf_spectacular.utils import extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import BooleanField, CharField, ChoiceField | ||||
| from rest_framework.request import Request | ||||
| @ -9,18 +7,12 @@ from rest_framework.response import Response | ||||
|  | ||||
| from authentik.core.api.utils import ModelSerializer, PassiveSerializer | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.events.api.tasks import SystemTaskSerializer | ||||
| from authentik.events.logs import LogEvent, LogEventSerializer | ||||
| from authentik.events.logs import LogEventSerializer | ||||
| from authentik.lib.sync.api import SyncStatusSerializer | ||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| from authentik.rbac.filters import ObjectFilter | ||||
|  | ||||
|  | ||||
| class SyncStatusSerializer(PassiveSerializer): | ||||
|     """Provider sync status""" | ||||
|  | ||||
|     is_running = BooleanField(read_only=True) | ||||
|     tasks = SystemTaskSerializer(many=True, read_only=True) | ||||
| from authentik.tasks.models import Task, TaskStatus | ||||
|  | ||||
|  | ||||
| class SyncObjectSerializer(PassiveSerializer): | ||||
| @ -45,15 +37,10 @@ class SyncObjectResultSerializer(PassiveSerializer): | ||||
| class OutgoingSyncProviderStatusMixin: | ||||
|     """Common API Endpoints for Outgoing sync providers""" | ||||
|  | ||||
|     sync_single_task: type[Task] = None | ||||
|     sync_objects_task: type[Task] = None | ||||
|     sync_task: Actor | ||||
|     sync_objects_task: Actor | ||||
|  | ||||
|     @extend_schema( | ||||
|         responses={ | ||||
|             200: SyncStatusSerializer(), | ||||
|             404: OpenApiResponse(description="Task not found"), | ||||
|         } | ||||
|     ) | ||||
|     @extend_schema(responses={200: SyncStatusSerializer()}) | ||||
|     @action( | ||||
|         methods=["GET"], | ||||
|         detail=True, | ||||
| @ -64,18 +51,39 @@ class OutgoingSyncProviderStatusMixin: | ||||
|     def sync_status(self, request: Request, pk: int) -> Response: | ||||
|         """Get provider's sync status""" | ||||
|         provider: OutgoingSyncProvider = self.get_object() | ||||
|         tasks = list( | ||||
|             get_objects_for_user(request.user, "authentik_events.view_systemtask").filter( | ||||
|                 name=self.sync_single_task.__name__, | ||||
|                 uid=slugify(provider.name), | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         status = {} | ||||
|  | ||||
|         with provider.sync_lock as lock_acquired: | ||||
|             status = { | ||||
|                 "tasks": tasks, | ||||
|                 # If we could not acquire the lock, it means a task is using it, and thus is running | ||||
|                 "is_running": not lock_acquired, | ||||
|             } | ||||
|             # If we could not acquire the lock, it means a task is using it, and thus is running | ||||
|             status["is_running"] = not lock_acquired | ||||
|  | ||||
|         sync_schedule = None | ||||
|         for schedule in provider.schedules.all(): | ||||
|             if schedule.actor_name == self.sync_task.actor_name: | ||||
|                 sync_schedule = schedule | ||||
|  | ||||
|         if not sync_schedule: | ||||
|             return Response(SyncStatusSerializer(status).data) | ||||
|  | ||||
|         last_task: Task = ( | ||||
|             sync_schedule.tasks.exclude( | ||||
|                 aggregated_status__in=(TaskStatus.CONSUMED, TaskStatus.QUEUED) | ||||
|             ) | ||||
|             .order_by("-mtime") | ||||
|             .first() | ||||
|         ) | ||||
|         last_successful_task: Task = ( | ||||
|             sync_schedule.tasks.filter(aggregated_status__in=(TaskStatus.DONE, TaskStatus.INFO)) | ||||
|             .order_by("-mtime") | ||||
|             .first() | ||||
|         ) | ||||
|  | ||||
|         if last_task: | ||||
|             status["last_sync_status"] = last_task.aggregated_status | ||||
|         if last_successful_task: | ||||
|             status["last_successful_sync"] = last_successful_task.mtime | ||||
|  | ||||
|         return Response(SyncStatusSerializer(status).data) | ||||
|  | ||||
|     @extend_schema( | ||||
| @ -94,14 +102,20 @@ class OutgoingSyncProviderStatusMixin: | ||||
|         provider: OutgoingSyncProvider = self.get_object() | ||||
|         params = SyncObjectSerializer(data=request.data) | ||||
|         params.is_valid(raise_exception=True) | ||||
|         res: list[LogEvent] = self.sync_objects_task.delay( | ||||
|             params.validated_data["sync_object_model"], | ||||
|             page=1, | ||||
|             provider_pk=provider.pk, | ||||
|             pk=params.validated_data["sync_object_id"], | ||||
|             override_dry_run=params.validated_data["override_dry_run"], | ||||
|         ).get() | ||||
|         return Response(SyncObjectResultSerializer(instance={"messages": res}).data) | ||||
|         msg = self.sync_objects_task.send_with_options( | ||||
|             kwargs={ | ||||
|                 "object_type": params.validated_data["sync_object_model"], | ||||
|                 "page": 1, | ||||
|                 "provider_pk": provider.pk, | ||||
|                 "override_dry_run": params.validated_data["override_dry_run"], | ||||
|                 "pk": params.validated_data["sync_object_id"], | ||||
|             }, | ||||
|             rel_obj=provider, | ||||
|         ) | ||||
|         msg.get_result(block=True) | ||||
|         task: Task = msg.options["task"] | ||||
|         task.refresh_from_db() | ||||
|         return Response(SyncObjectResultSerializer(instance={"messages": task._messages}).data) | ||||
|  | ||||
|  | ||||
| class OutgoingSyncConnectionCreateMixin: | ||||
|  | ||||
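`msg.get_result(block=True)` is stock dramatiq (via the Results middleware); reading `msg.options["task"]` and `task._messages` afterwards is authentik's postgres-backed extension, as seen in the hunk. The stock block-for-result half, runnable against dramatiq's stub broker and result backend:

```python
import dramatiq
from dramatiq import Worker
from dramatiq.brokers.stub import StubBroker
from dramatiq.results import Results
from dramatiq.results.backends.stub import StubBackend

broker = StubBroker()
broker.add_middleware(Results(backend=StubBackend()))
dramatiq.set_broker(broker)


@dramatiq.actor(store_results=True)
def add(a: int, b: int) -> int:
    return a + b


worker = Worker(broker, worker_timeout=100)
worker.start()
try:
    msg = add.send_with_options(args=(1, 2))
    broker.join(add.queue_name)  # wait until the queue drains
    worker.join()                # wait until in-flight work finishes
    print(msg.get_result(block=True))  # -> 3
finally:
    worker.stop()
```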
| @ -1,12 +1,18 @@ | ||||
| from typing import Any, Self | ||||
|  | ||||
| import pglock | ||||
| from django.core.paginator import Paginator | ||||
| from django.db import connection, models | ||||
| from django.db.models import Model, QuerySet, TextChoices | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dramatiq.actor import Actor | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS | ||||
| from authentik.lib.sync.outgoing.base import BaseOutgoingSyncClient | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
| from authentik.tasks.schedules.models import ScheduledModel | ||||
|  | ||||
|  | ||||
| class OutgoingSyncDeleteAction(TextChoices): | ||||
| @ -18,7 +24,7 @@ class OutgoingSyncDeleteAction(TextChoices): | ||||
|     SUSPEND = "suspend" | ||||
|  | ||||
|  | ||||
| class OutgoingSyncProvider(Model): | ||||
| class OutgoingSyncProvider(ScheduledModel, Model): | ||||
|     """Base abstract models for providers implementing outgoing sync""" | ||||
|  | ||||
|     dry_run = models.BooleanField( | ||||
| @ -39,6 +45,19 @@ class OutgoingSyncProvider(Model): | ||||
|     def get_object_qs[T: User | Group](self, type: type[T]) -> QuerySet[T]: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def get_paginator[T: User | Group](self, type: type[T]) -> Paginator: | ||||
|         return Paginator(self.get_object_qs(type), PAGE_SIZE) | ||||
|  | ||||
|     def get_object_sync_time_limit_ms[T: User | Group](self, type: type[T]) -> int: | ||||
|         num_pages: int = self.get_paginator(type).num_pages | ||||
|         return int(num_pages * PAGE_TIMEOUT_MS * 1.5) | ||||
|  | ||||
|     def get_sync_time_limit_ms(self) -> int: | ||||
|         return int( | ||||
|             (self.get_object_sync_time_limit_ms(User) + self.get_object_sync_time_limit_ms(Group)) | ||||
|             * 1.5 | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def sync_lock(self) -> pglock.advisory: | ||||
|         """Postgres lock for syncing to prevent multiple parallel syncs happening""" | ||||
| @ -47,3 +66,22 @@ class OutgoingSyncProvider(Model): | ||||
|             timeout=0, | ||||
|             side_effect=pglock.Return, | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def sync_actor(self) -> Actor: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     @property | ||||
|     def schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=self.sync_actor, | ||||
|                 uid=self.pk, | ||||
|                 args=(self.pk,), | ||||
|                 options={ | ||||
|                     "time_limit": self.get_sync_time_limit_ms(), | ||||
|                 }, | ||||
|                 send_on_save=True, | ||||
|                 crontab=f"{fqdn_rand(self.pk)} */4 * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
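The schedule above also derives the dramatiq `time_limit` from the paginator sizes, compounding two 1.5x safety margins: per object type, `num_pages * PAGE_TIMEOUT_MS * 1.5`, then another 1.5x over the user+group total. Worked example assuming 5 user pages and 2 group pages:

```python
PAGE_TIMEOUT_MS = 1_800_000  # 30 minutes per page, from the constants hunk

users_ms = int(5 * PAGE_TIMEOUT_MS * 1.5)   # 13_500_000 ms = 3 h 45 min
groups_ms = int(2 * PAGE_TIMEOUT_MS * 1.5)  #  5_400_000 ms = 1 h 30 min
total_ms = int((users_ms + groups_ms) * 1.5)
print(total_ms)  # 28_350_000 ms, roughly 7 h 52 min for the full sync
```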
| @ -1,12 +1,8 @@ | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.core.paginator import Paginator | ||||
| from django.db.models import Model | ||||
| from django.db.models.query import Q | ||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete | ||||
| from dramatiq.actor import Actor | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT | ||||
| from authentik.lib.sync.outgoing.base import Direction | ||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| @ -14,45 +10,30 @@ from authentik.lib.utils.reflection import class_to_path | ||||
|  | ||||
| def register_signals( | ||||
|     provider_type: type[OutgoingSyncProvider], | ||||
|     task_sync_single: Callable[[int], None], | ||||
|     task_sync_direct: Callable[[int], None], | ||||
|     task_sync_m2m: Callable[[int], None], | ||||
|     task_sync_direct_dispatch: Actor[[str, str | int, str], None], | ||||
|     task_sync_m2m_dispatch: Actor[[str, str, list[str], bool], None], | ||||
| ): | ||||
|     """Register sync signals""" | ||||
|     uid = class_to_path(provider_type) | ||||
|  | ||||
|     def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, created: bool, **_): | ||||
|         """Trigger sync when Provider is saved""" | ||||
|         users_paginator = Paginator(instance.get_object_qs(User), PAGE_SIZE) | ||||
|         groups_paginator = Paginator(instance.get_object_qs(Group), PAGE_SIZE) | ||||
|         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT | ||||
|         time_limit = soft_time_limit * 1.5 | ||||
|         task_sync_single.apply_async( | ||||
|             (instance.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) | ||||
|         ) | ||||
|  | ||||
|     post_save.connect(post_save_provider, provider_type, dispatch_uid=uid, weak=False) | ||||
|  | ||||
|     def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_): | ||||
|         """Post save handler""" | ||||
|         if not provider_type.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ).exists(): | ||||
|             return | ||||
|         task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value) | ||||
|         task_sync_direct_dispatch.send( | ||||
|             class_to_path(instance.__class__), | ||||
|             instance.pk, | ||||
|             Direction.add.value, | ||||
|         ) | ||||
|  | ||||
|     post_save.connect(model_post_save, User, dispatch_uid=uid, weak=False) | ||||
|     post_save.connect(model_post_save, Group, dispatch_uid=uid, weak=False) | ||||
|  | ||||
|     def model_pre_delete(sender: type[Model], instance: User | Group, **_): | ||||
|         """Pre-delete handler""" | ||||
|         if not provider_type.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ).exists(): | ||||
|             return | ||||
|         task_sync_direct.delay( | ||||
|             class_to_path(instance.__class__), instance.pk, Direction.remove.value | ||||
|         ).get(propagate=False) | ||||
|         task_sync_direct_dispatch.send( | ||||
|             class_to_path(instance.__class__), | ||||
|             instance.pk, | ||||
|             Direction.remove.value, | ||||
|         ) | ||||
|  | ||||
|     pre_delete.connect(model_pre_delete, User, dispatch_uid=uid, weak=False) | ||||
|     pre_delete.connect(model_pre_delete, Group, dispatch_uid=uid, weak=False) | ||||
| @ -63,16 +44,6 @@ def register_signals( | ||||
|         """Sync group membership""" | ||||
|         if action not in ["post_add", "post_remove"]: | ||||
|             return | ||||
|         if not provider_type.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ).exists(): | ||||
|             return | ||||
|         # reverse: instance is a Group, pk_set is a list of user pks | ||||
|         # non-reverse: instance is a User, pk_set is a list of groups | ||||
|         if reverse: | ||||
|             task_sync_m2m.delay(str(instance.pk), action, list(pk_set)) | ||||
|         else: | ||||
|             for group_pk in pk_set: | ||||
|                 task_sync_m2m.delay(group_pk, action, [instance.pk]) | ||||
|         task_sync_m2m_dispatch.send(instance.pk, action, list(pk_set), reverse) | ||||
|  | ||||
|     m2m_changed.connect(model_m2m_changed, User.ak_groups.through, dispatch_uid=uid, weak=False) | ||||
|  | ||||
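The reverse/non-reverse branch (Group instance with user pks vs. User instance with group pks) moves out of the signal handler into the new dispatch actor, so the handler only forwards raw arguments. What the dispatch side plausibly does with the `reverse` flag, reconstructed as an assumption from the deleted branch:

```python
def dispatch_m2m(instance_pk, action: str, pk_set: list, reverse: bool):
    """Hypothetical dispatch body mirroring the deleted signal branch.

    reverse=True: instance is a Group, pk_set holds user pks.
    reverse=False: instance is a User, pk_set holds group pks.
    """
    if reverse:
        sync_group(instance_pk, action, pk_set)
    else:
        for group_pk in pk_set:
            sync_group(group_pk, action, [instance_pk])


def sync_group(group_pk, action: str, user_pks: list):
    print(f"sync group={group_pk} action={action} users={user_pks}")


dispatch_m2m("user-1", "post_add", ["group-a", "group-b"], reverse=False)
```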
| @ -1,23 +1,17 @@ | ||||
| from collections.abc import Callable | ||||
| from dataclasses import asdict | ||||
|  | ||||
| from celery import group | ||||
| from celery.exceptions import Retry | ||||
| from celery.result import allow_join_result | ||||
| from django.core.paginator import Paginator | ||||
| from django.db.models import Model, QuerySet | ||||
| from django.db.models.query import Q | ||||
| from django.utils.text import slugify | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from dramatiq.actor import Actor | ||||
| from dramatiq.composition import group | ||||
| from dramatiq.errors import Retry | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.core.expression.exceptions import SkipObjectException | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.events.logs import LogEvent | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT | ||||
| from authentik.lib.sync.outgoing import PAGE_SIZE, PAGE_TIMEOUT_MS | ||||
| from authentik.lib.sync.outgoing.base import Direction | ||||
| from authentik.lib.sync.outgoing.exceptions import ( | ||||
|     BadRequestSyncException, | ||||
| @ -27,11 +21,12 @@ from authentik.lib.sync.outgoing.exceptions import ( | ||||
| ) | ||||
| from authentik.lib.sync.outgoing.models import OutgoingSyncProvider | ||||
| from authentik.lib.utils.reflection import class_to_path, path_to_class | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
|  | ||||
| class SyncTasks: | ||||
|     """Container for all sync 'tasks' (this class doesn't actually contain celery | ||||
|     tasks due to celery's magic, however exposes a number of functions to be called from tasks)""" | ||||
|     """Container for all sync 'tasks' (this class doesn't actually contain | ||||
|     tasks due to dramatiq's magic; it instead exposes functions to be called from tasks)""" | ||||
|  | ||||
|     logger: BoundLogger | ||||
|  | ||||
| @ -39,107 +34,104 @@ class SyncTasks: | ||||
|         super().__init__() | ||||
|         self._provider_model = provider_model | ||||
|  | ||||
|     def sync_all(self, single_sync: Callable[[int], None]): | ||||
|         for provider in self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ): | ||||
|             self.trigger_single_task(provider, single_sync) | ||||
|  | ||||
|     def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]): | ||||
|         """Wrapper single sync task that correctly sets time limits based | ||||
|         on the amount of objects that will be synced""" | ||||
|         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) | ||||
|         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) | ||||
|         soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT | ||||
|         time_limit = soft_time_limit * 1.5 | ||||
|         return sync_task.apply_async( | ||||
|             (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) | ||||
|         ) | ||||
|  | ||||
|     def sync_single( | ||||
|     def sync_paginator( | ||||
|         self, | ||||
|         task: SystemTask, | ||||
|         provider_pk: int, | ||||
|         sync_objects: Callable[[int, int], list[str]], | ||||
|         current_task: Task, | ||||
|         provider: OutgoingSyncProvider, | ||||
|         sync_objects: Actor[[str, int, int, bool], None], | ||||
|         paginator: Paginator, | ||||
|         object_type: type[User | Group], | ||||
|         **options, | ||||
|     ): | ||||
|         tasks = [] | ||||
|         for page in paginator.page_range: | ||||
|             page_sync = sync_objects.message_with_options( | ||||
|                 args=(class_to_path(object_type), page, provider.pk), | ||||
|                 time_limit=PAGE_TIMEOUT_MS, | ||||
|                 # Assign tasks to the same schedule as the current one | ||||
|                 rel_obj=current_task.rel_obj, | ||||
|                 **options, | ||||
|             ) | ||||
|             tasks.append(page_sync) | ||||
|         return tasks | ||||
|  | ||||
|     def sync( | ||||
|         self, | ||||
|         provider_pk: int, | ||||
|         sync_objects: Actor[[str, int, int, bool], None], | ||||
|     ): | ||||
|         task: Task = CurrentTask.get_task() | ||||
|         self.logger = get_logger().bind( | ||||
|             provider_type=class_to_path(self._provider_model), | ||||
|             provider_pk=provider_pk, | ||||
|         ) | ||||
|         provider = self._provider_model.objects.filter( | ||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||
|             pk=provider_pk, | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             task.warning("No provider found. Is it assigned to an application?") | ||||
|             return | ||||
|         task.set_uid(slugify(provider.name)) | ||||
|         messages = [] | ||||
|         messages.append(_("Starting full provider sync")) | ||||
|         task.info("Starting full provider sync") | ||||
|         self.logger.debug("Starting provider sync") | ||||
|         users_paginator = Paginator(provider.get_object_qs(User), PAGE_SIZE) | ||||
|         groups_paginator = Paginator(provider.get_object_qs(Group), PAGE_SIZE) | ||||
|         with allow_join_result(), provider.sync_lock as lock_acquired: | ||||
|         with provider.sync_lock as lock_acquired: | ||||
|             if not lock_acquired: | ||||
|                 task.info("Synchronization is already running. Skipping.") | ||||
|                 self.logger.debug("Failed to acquire sync lock, skipping", provider=provider.name) | ||||
|                 return | ||||
|             try: | ||||
|                 messages.append(_("Syncing users")) | ||||
|                 user_results = ( | ||||
|                     group( | ||||
|                         [ | ||||
|                             sync_objects.signature( | ||||
|                                 args=(class_to_path(User), page, provider_pk), | ||||
|                                 time_limit=PAGE_TIMEOUT, | ||||
|                                 soft_time_limit=PAGE_TIMEOUT, | ||||
|                             ) | ||||
|                             for page in users_paginator.page_range | ||||
|                         ] | ||||
|                 users_tasks = group( | ||||
|                     self.sync_paginator( | ||||
|                         current_task=task, | ||||
|                         provider=provider, | ||||
|                         sync_objects=sync_objects, | ||||
|                         paginator=provider.get_paginator(User), | ||||
|                         object_type=User, | ||||
|                     ) | ||||
|                     .apply_async() | ||||
|                     .get() | ||||
|                 ) | ||||
|                 for result in user_results: | ||||
|                     for msg in result: | ||||
|                         messages.append(LogEvent(**msg)) | ||||
|                 messages.append(_("Syncing groups")) | ||||
|                 group_results = ( | ||||
|                     group( | ||||
|                         [ | ||||
|                             sync_objects.signature( | ||||
|                                 args=(class_to_path(Group), page, provider_pk), | ||||
|                                 time_limit=PAGE_TIMEOUT, | ||||
|                                 soft_time_limit=PAGE_TIMEOUT, | ||||
|                             ) | ||||
|                             for page in groups_paginator.page_range | ||||
|                         ] | ||||
|                 group_tasks = group( | ||||
|                     self.sync_paginator( | ||||
|                         current_task=task, | ||||
|                         provider=provider, | ||||
|                         sync_objects=sync_objects, | ||||
|                         paginator=provider.get_paginator(Group), | ||||
|                         object_type=Group, | ||||
|                     ) | ||||
|                     .apply_async() | ||||
|                     .get() | ||||
|                 ) | ||||
|                 for result in group_results: | ||||
|                     for msg in result: | ||||
|                         messages.append(LogEvent(**msg)) | ||||
|                 users_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(User)) | ||||
|                 group_tasks.run().wait(timeout=provider.get_object_sync_time_limit_ms(Group)) | ||||
|             except TransientSyncException as exc: | ||||
|                 self.logger.warning("transient sync exception", exc=exc) | ||||
|                 raise task.retry(exc=exc) from exc | ||||
|                 task.warning("Sync encountered a transient exception. Retrying", exc=exc) | ||||
|                 raise Retry() from exc | ||||
|             except StopSync as exc: | ||||
|                 task.set_error(exc) | ||||
|                 task.error(exc) | ||||
|                 return | ||||
|         task.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|  | ||||
|     def sync_objects( | ||||
|         self, object_type: str, page: int, provider_pk: int, override_dry_run=False, **filter | ||||
|         self, | ||||
|         object_type: str, | ||||
|         page: int, | ||||
|         provider_pk: int, | ||||
|         override_dry_run=False, | ||||
|         **filter, | ||||
|     ): | ||||
|         task: Task = CurrentTask.get_task() | ||||
|         _object_type: type[Model] = path_to_class(object_type) | ||||
|         self.logger = get_logger().bind( | ||||
|             provider_type=class_to_path(self._provider_model), | ||||
|             provider_pk=provider_pk, | ||||
|             object_type=object_type, | ||||
|         ) | ||||
|         messages = [] | ||||
|         provider = self._provider_model.objects.filter(pk=provider_pk).first() | ||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||
|             pk=provider_pk, | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             return messages | ||||
|             task.warning("No provider found. Is it assigned to an application?") | ||||
|             return | ||||
|         task.set_uid(slugify(provider.name)) | ||||
|         # Override dry run mode if requested; however, don't save the provider | ||||
|         # so that scheduled sync tasks still run in dry_run mode | ||||
|         if override_dry_run: | ||||
| @ -147,25 +139,13 @@ class SyncTasks: | ||||
|         try: | ||||
|             client = provider.client_for_model(_object_type) | ||||
|         except TransientSyncException: | ||||
|             return messages | ||||
|             return | ||||
|         paginator = Paginator(provider.get_object_qs(_object_type).filter(**filter), PAGE_SIZE) | ||||
|         if client.can_discover: | ||||
|             self.logger.debug("starting discover") | ||||
|             client.discover() | ||||
|         self.logger.debug("starting sync for page", page=page) | ||||
|         messages.append( | ||||
|             asdict( | ||||
|                 LogEvent( | ||||
|                     _( | ||||
|                         "Syncing page {page} of {object_type}".format( | ||||
|                             page=page, object_type=_object_type._meta.verbose_name_plural | ||||
|                         ) | ||||
|                     ), | ||||
|                     log_level="info", | ||||
|                     logger=f"{provider._meta.verbose_name}@{object_type}", | ||||
|                 ) | ||||
|             ) | ||||
|         ) | ||||
|         task.info(f"Syncing page {page} or {_object_type._meta.verbose_name_plural}") | ||||
|         for obj in paginator.page(page).object_list: | ||||
|             obj: Model | ||||
|             try: | ||||
| @ -174,89 +154,58 @@ class SyncTasks: | ||||
|                 self.logger.debug("skipping object due to SkipObject", obj=obj) | ||||
|                 continue | ||||
|             except DryRunRejected as exc: | ||||
|                 messages.append( | ||||
|                     asdict( | ||||
|                         LogEvent( | ||||
|                             _("Dropping mutating request due to dry run"), | ||||
|                             log_level="info", | ||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||
|                             attributes={ | ||||
|                                 "obj": sanitize_item(obj), | ||||
|                                 "method": exc.method, | ||||
|                                 "url": exc.url, | ||||
|                                 "body": exc.body, | ||||
|                             }, | ||||
|                         ) | ||||
|                     ) | ||||
|                 task.info( | ||||
|                     "Dropping mutating request due to dry run", | ||||
|                     obj=sanitize_item(obj), | ||||
|                     method=exc.method, | ||||
|                     url=exc.url, | ||||
|                     body=exc.body, | ||||
|                 ) | ||||
|             except BadRequestSyncException as exc: | ||||
|                 self.logger.warning("failed to sync object", exc=exc, obj=obj) | ||||
|                 messages.append( | ||||
|                     asdict( | ||||
|                         LogEvent( | ||||
|                             _( | ||||
|                                 ( | ||||
|                                     "Failed to sync {object_type} {object_name} " | ||||
|                                     "due to error: {error}" | ||||
|                                 ).format_map( | ||||
|                                     { | ||||
|                                         "object_type": obj._meta.verbose_name, | ||||
|                                         "object_name": str(obj), | ||||
|                                         "error": str(exc), | ||||
|                                     } | ||||
|                                 ) | ||||
|                             ), | ||||
|                             log_level="warning", | ||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||
|                             attributes={"arguments": exc.args[1:], "obj": sanitize_item(obj)}, | ||||
|                         ) | ||||
|                     ) | ||||
|                 task.warning( | ||||
|                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to error: {str(exc)}", | ||||
|                     arguments=exc.args[1:], | ||||
|                     obj=sanitize_item(obj), | ||||
|                 ) | ||||
|             except TransientSyncException as exc: | ||||
|                 self.logger.warning("failed to sync object", exc=exc, user=obj) | ||||
|                 messages.append( | ||||
|                     asdict( | ||||
|                         LogEvent( | ||||
|                             _( | ||||
|                                 ( | ||||
|                                     "Failed to sync {object_type} {object_name} " | ||||
|                                     "due to transient error: {error}" | ||||
|                                 ).format_map( | ||||
|                                     { | ||||
|                                         "object_type": obj._meta.verbose_name, | ||||
|                                         "object_name": str(obj), | ||||
|                                         "error": str(exc), | ||||
|                                     } | ||||
|                                 ) | ||||
|                             ), | ||||
|                             log_level="warning", | ||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||
|                             attributes={"obj": sanitize_item(obj)}, | ||||
|                         ) | ||||
|                     ) | ||||
|                 task.warning( | ||||
|                     f"Failed to sync {obj._meta.verbose_name} {str(obj)} due to " | ||||
|                     "transient error: {str(exc)}", | ||||
|                     obj=sanitize_item(obj), | ||||
|                 ) | ||||
|             except StopSync as exc: | ||||
|                 self.logger.warning("Stopping sync", exc=exc) | ||||
|                 messages.append( | ||||
|                     asdict( | ||||
|                         LogEvent( | ||||
|                             _( | ||||
|                                 "Stopping sync due to error: {error}".format_map( | ||||
|                                     { | ||||
|                                         "error": exc.detail(), | ||||
|                                     } | ||||
|                                 ) | ||||
|                             ), | ||||
|                             log_level="warning", | ||||
|                             logger=f"{provider._meta.verbose_name}@{object_type}", | ||||
|                             attributes={"obj": sanitize_item(obj)}, | ||||
|                         ) | ||||
|                     ) | ||||
|                 task.warning( | ||||
|                     f"Stopping sync due to error: {exc.detail()}", | ||||
|                     obj=sanitize_item(obj), | ||||
|                 ) | ||||
|                 break | ||||
|         return messages | ||||
|  | ||||
|     def sync_signal_direct(self, model: str, pk: str | int, raw_op: str): | ||||
|     def sync_signal_direct_dispatch( | ||||
|         self, | ||||
|         task_sync_signal_direct: Actor[[str, str | int, int, str], None], | ||||
|         model: str, | ||||
|         pk: str | int, | ||||
|         raw_op: str, | ||||
|     ): | ||||
|         for provider in self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ): | ||||
|             task_sync_signal_direct.send_with_options( | ||||
|                 args=(model, pk, provider.pk, raw_op), | ||||
|                 rel_obj=provider, | ||||
|             ) | ||||
|  | ||||
|     def sync_signal_direct( | ||||
|         self, | ||||
|         model: str, | ||||
|         pk: str | int, | ||||
|         provider_pk: int, | ||||
|         raw_op: str, | ||||
|     ): | ||||
|         task: Task = CurrentTask.get_task() | ||||
|         self.logger = get_logger().bind( | ||||
|             provider_type=class_to_path(self._provider_model), | ||||
|         ) | ||||
| @ -264,65 +213,108 @@ class SyncTasks: | ||||
|         instance = model_class.objects.filter(pk=pk).first() | ||||
|         if not instance: | ||||
|             return | ||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||
|             pk=provider_pk, | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             task.warning("No provider found. Is it assigned to an application?") | ||||
|             return | ||||
|         task.set_uid(slugify(provider.name)) | ||||
|         operation = Direction(raw_op) | ||||
|         client = provider.client_for_model(instance.__class__) | ||||
|         # Check if the object is allowed within the provider's restrictions | ||||
|         queryset = provider.get_object_qs(instance.__class__) | ||||
|         if not queryset: | ||||
|             return | ||||
|  | ||||
|         # The queryset we get from the provider must include the instance we've been given; | ||||
|         # otherwise, ignore this provider | ||||
|         if not queryset.filter(pk=instance.pk).exists(): | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             if operation == Direction.add: | ||||
|                 client.write(instance) | ||||
|             if operation == Direction.remove: | ||||
|                 client.delete(instance) | ||||
|         except TransientSyncException as exc: | ||||
|             raise Retry() from exc | ||||
|         except SkipObjectException: | ||||
|             return | ||||
|         except DryRunRejected as exc: | ||||
|             self.logger.info("Rejected dry-run event", exc=exc) | ||||
|         except StopSync as exc: | ||||
|             self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||
|  | ||||
|     def sync_signal_m2m_dispatch( | ||||
|         self, | ||||
|         task_sync_signal_m2m: Actor[[str, int, str, list[int]], None], | ||||
|         instance_pk: str, | ||||
|         action: str, | ||||
|         pk_set: list[int], | ||||
|         reverse: bool, | ||||
|     ): | ||||
|         for provider in self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ): | ||||
|             client = provider.client_for_model(instance.__class__) | ||||
|             # Check if the object is allowed within the provider's restrictions | ||||
|             queryset = provider.get_object_qs(instance.__class__) | ||||
|             if not queryset: | ||||
|                 continue | ||||
|             # reverse: instance is a Group, pk_set is a list of user pks | ||||
|             # non-reverse: instance is a User, pk_set is a list of group pks | ||||
|             if reverse: | ||||
|                 task_sync_signal_m2m.send_with_options( | ||||
|                     args=(instance_pk, provider.pk, action, list(pk_set)), | ||||
|                     rel_obj=provider, | ||||
|                 ) | ||||
|             else: | ||||
|                 for pk in pk_set: | ||||
|                     task_sync_signal_m2m.send_with_options( | ||||
|                         args=(pk, provider.pk, action, [instance_pk]), | ||||
|                         rel_obj=provider, | ||||
|                     ) | ||||
|  | ||||
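As the comments in `sync_signal_m2m_dispatch` note, Django delivers `m2m_changed` from both sides of the relation, and the dispatcher flattens both shapes so each message carries one group pk plus the affected user pks. A self-contained sketch of that normalization (plain Python; the function name is illustrative, not part of this diff):

```python
# Sketch of the normalization sync_signal_m2m_dispatch performs:
# reverse=True  -> instance is the Group, pk_set are user pks
# reverse=False -> instance is the User, pk_set are group pks
def normalize_m2m(instance_pk, pk_set, reverse):
    if reverse:
        # one message for the group, carrying all affected users
        return [(instance_pk, list(pk_set))]
    # one message per group, each carrying the single affected user
    return [(group_pk, [instance_pk]) for group_pk in pk_set]


assert normalize_m2m("g1", [1, 2], True) == [("g1", [1, 2])]
assert normalize_m2m(7, ["g1"], False) == [("g1", [7])]
```
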
|             # The queryset we get from the provider must include the instance we've got given | ||||
|             # otherwise ignore this provider | ||||
|             if not queryset.filter(pk=instance.pk).exists(): | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 if operation == Direction.add: | ||||
|                     client.write(instance) | ||||
|                 if operation == Direction.remove: | ||||
|                     client.delete(instance) | ||||
|             except TransientSyncException as exc: | ||||
|                 raise Retry() from exc | ||||
|             except SkipObjectException: | ||||
|                 continue | ||||
|             except DryRunRejected as exc: | ||||
|                 self.logger.info("Rejected dry-run event", exc=exc) | ||||
|             except StopSync as exc: | ||||
|                 self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||
|  | ||||
|     def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]): | ||||
|     def sync_signal_m2m( | ||||
|         self, | ||||
|         group_pk: str, | ||||
|         provider_pk: int, | ||||
|         action: str, | ||||
|         pk_set: list[int], | ||||
|     ): | ||||
|         task: Task = CurrentTask.get_task() | ||||
|         self.logger = get_logger().bind( | ||||
|             provider_type=class_to_path(self._provider_model), | ||||
|         ) | ||||
|         group = Group.objects.filter(pk=group_pk).first() | ||||
|         if not group: | ||||
|             return | ||||
|         for provider in self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False) | ||||
|         ): | ||||
|             # Check if the object is allowed within the provider's restrictions | ||||
|             queryset: QuerySet = provider.get_object_qs(Group) | ||||
|             # The queryset we get from the provider must include the instance we've got given | ||||
|             # otherwise ignore this provider | ||||
|             if not queryset.filter(pk=group_pk).exists(): | ||||
|                 continue | ||||
|         provider: OutgoingSyncProvider = self._provider_model.objects.filter( | ||||
|             Q(backchannel_application__isnull=False) | Q(application__isnull=False), | ||||
|             pk=provider_pk, | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             task.warning("No provider found. Is it assigned to an application?") | ||||
|             return | ||||
|         task.set_uid(slugify(provider.name)) | ||||
|  | ||||
|             client = provider.client_for_model(Group) | ||||
|             try: | ||||
|                 operation = None | ||||
|                 if action == "post_add": | ||||
|                     operation = Direction.add | ||||
|                 if action == "post_remove": | ||||
|                     operation = Direction.remove | ||||
|                 client.update_group(group, operation, pk_set) | ||||
|             except TransientSyncException as exc: | ||||
|                 raise Retry() from exc | ||||
|             except SkipObjectException: | ||||
|                 continue | ||||
|             except DryRunRejected as exc: | ||||
|                 self.logger.info("Rejected dry-run event", exc=exc) | ||||
|             except StopSync as exc: | ||||
|                 self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||
|         # Check if the object is allowed within the provider's restrictions | ||||
|         queryset: QuerySet = provider.get_object_qs(Group) | ||||
|         # The queryset we get from the provider must include the instance we've been given; | ||||
|         # otherwise, ignore this provider | ||||
|         if not queryset.filter(pk=group_pk).exists(): | ||||
|             return | ||||
|  | ||||
|         client = provider.client_for_model(Group) | ||||
|         try: | ||||
|             operation = None | ||||
|             if action == "post_add": | ||||
|                 operation = Direction.add | ||||
|             if action == "post_remove": | ||||
|                 operation = Direction.remove | ||||
|             client.update_group(group, operation, pk_set) | ||||
|         except TransientSyncException as exc: | ||||
|             raise Retry() from exc | ||||
|         except SkipObjectException: | ||||
|             return | ||||
|         except DryRunRejected as exc: | ||||
|             self.logger.info("Rejected dry-run event", exc=exc) | ||||
|         except StopSync as exc: | ||||
|             self.logger.warning("Stopping sync", exc=exc, provider_pk=provider.pk) | ||||
|  | ||||
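The `SyncTasks` container above keeps the sync logic out of the actors themselves. A minimal sketch of how a provider package might wire dramatiq actors to these methods; the provider model, actor names, and descriptions here are illustrative assumptions, not taken from this diff:

```python
# Sketch only: actor names and the provider model are assumed for illustration.
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor

from authentik.lib.sync.outgoing.tasks import SyncTasks
from authentik.providers.scim.models import SCIMProvider  # assumed provider model

sync_tasks = SyncTasks(SCIMProvider)


@actor(description=_("Sync SCIM provider objects (single page)."))
def scim_sync_objects(object_type: str, page: int, provider_pk: int, override_dry_run: bool = False):
    # The actor is a thin shim; all logic lives on the shared SyncTasks instance.
    sync_tasks.sync_objects(object_type, page, provider_pk, override_dry_run)


@actor(description=_("Full sync for a SCIM provider."))
def scim_sync(provider_pk: int):
    # sync() fans pages out to scim_sync_objects via sync_paginator()
    sync_tasks.sync(provider_pk, scim_sync_objects)
```
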
| @ -5,6 +5,8 @@ from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -60,3 +62,27 @@ class AuthentikOutpostConfig(ManagedAppConfig): | ||||
|                 outpost.save() | ||||
|         else: | ||||
|             Outpost.objects.filter(managed=MANAGED_OUTPOST).delete() | ||||
|  | ||||
|     @property | ||||
|     def tenant_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.outposts.tasks import outpost_token_ensurer | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=outpost_token_ensurer, | ||||
|                 crontab=f"{fqdn_rand('outpost_token_ensurer')} */8 * * *", | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|     @property | ||||
|     def global_schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.outposts.tasks import outpost_connection_discovery | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=outpost_connection_discovery, | ||||
|                 crontab=f"{fqdn_rand('outpost_connection_discovery')} */8 * * *", | ||||
|                 send_on_startup=True, | ||||
|                 paused=not CONFIG.get_bool("outposts.discover"), | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
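Both specs stagger their crontab minute with `fqdn_rand` so that many installs don't all fire at the same moment. A sketch of what such a helper could look like, assuming a hash of the host FQDN plus a per-task seed (the real implementation in `authentik.lib.utils.time` may differ):

```python
# Sketch of an fqdn_rand-style helper (assumed behavior, not the real code):
# derive a stable, host-specific minute from the FQDN plus a per-task seed.
from hashlib import sha256
from socket import getfqdn


def fqdn_rand_sketch(seed: str, stop: int = 60) -> int:
    digest = sha256(f"{getfqdn()}/{seed}".encode()).hexdigest()
    return int(digest, 16) % stop


# e.g. f"{fqdn_rand_sketch('outpost_token_ensurer')} */8 * * *"
```
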
| @ -101,7 +101,13 @@ class KubernetesController(BaseController): | ||||
|             all_logs = [] | ||||
|             for reconcile_key in self.reconcile_order: | ||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||
|                     all_logs += [f"{reconcile_key.title()}: Disabled"] | ||||
|                     all_logs.append( | ||||
|                         LogEvent( | ||||
|                             log_level="info", | ||||
|                             event=f"{reconcile_key.title()}: Disabled", | ||||
|                             logger=str(type(self)), | ||||
|                         ) | ||||
|                     ) | ||||
|                     continue | ||||
|                 with capture_logs() as logs: | ||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||
| @ -134,7 +140,13 @@ class KubernetesController(BaseController): | ||||
|             all_logs = [] | ||||
|             for reconcile_key in self.reconcile_order: | ||||
|                 if reconcile_key in self.outpost.config.kubernetes_disabled_components: | ||||
|                     all_logs += [f"{reconcile_key.title()}: Disabled"] | ||||
|                     all_logs.append( | ||||
|                         LogEvent( | ||||
|                             log_level="info", | ||||
|                             event=f"{reconcile_key.title()}: Disabled", | ||||
|                             logger=str(type(self)), | ||||
|                         ) | ||||
|                     ) | ||||
|                     continue | ||||
|                 with capture_logs() as logs: | ||||
|                     reconciler_cls = self.reconcilers.get(reconcile_key) | ||||
|  | ||||
| @ -36,7 +36,10 @@ from authentik.lib.config import CONFIG | ||||
| from authentik.lib.models import InheritanceForeignKey, SerializerModel | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
| from authentik.outposts.controllers.k8s.utils import get_namespace | ||||
| from authentik.tasks.schedules.lib import ScheduleSpec | ||||
| from authentik.tasks.schedules.models import ScheduledModel | ||||
|  | ||||
| OUR_VERSION = parse(__version__) | ||||
| OUTPOST_HELLO_INTERVAL = 10 | ||||
| @ -115,7 +118,7 @@ class OutpostServiceConnectionState: | ||||
|     healthy: bool | ||||
|  | ||||
|  | ||||
| class OutpostServiceConnection(models.Model): | ||||
| class OutpostServiceConnection(ScheduledModel, models.Model): | ||||
|     """Connection details for an Outpost Controller, like Docker or Kubernetes""" | ||||
|  | ||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||
| @ -145,11 +148,11 @@ class OutpostServiceConnection(models.Model): | ||||
|     @property | ||||
|     def state(self) -> OutpostServiceConnectionState: | ||||
|         """Get state of service connection""" | ||||
|         from authentik.outposts.tasks import outpost_service_connection_state | ||||
|         from authentik.outposts.tasks import outpost_service_connection_monitor | ||||
|  | ||||
|         state = cache.get(self.state_key, None) | ||||
|         if not state: | ||||
|             outpost_service_connection_state.delay(self.pk) | ||||
|             outpost_service_connection_monitor.send_with_options(args=(self.pk,), rel_obj=self) | ||||
|             return OutpostServiceConnectionState("", False) | ||||
|         return state | ||||
|  | ||||
| @ -160,6 +163,20 @@ class OutpostServiceConnection(models.Model): | ||||
|         # since the response doesn't use the correct inheritance | ||||
|         return "" | ||||
|  | ||||
|     @property | ||||
|     def schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.outposts.tasks import outpost_service_connection_monitor | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=outpost_service_connection_monitor, | ||||
|                 uid=self.pk, | ||||
|                 args=(self.pk,), | ||||
|                 crontab="3-59/15 * * * *", | ||||
|                 send_on_save=True, | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
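The connection monitor's `crontab="3-59/15 * * * *"` above fires at minutes 3, 18, 33, and 48 of every hour. A quick way to confirm such an expression, using the third-party `croniter` package (an assumption here, not a dependency of this diff):

```python
# Expand the crontab used above; croniter is assumed to be installed.
from datetime import datetime

from croniter import croniter

it = croniter("3-59/15 * * * *", datetime(2024, 1, 1))
print([it.get_next(datetime).minute for _ in range(4)])  # -> [3, 18, 33, 48]
```
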
|  | ||||
| class DockerServiceConnection(SerializerModel, OutpostServiceConnection): | ||||
|     """Service Connection to a Docker endpoint""" | ||||
| @ -244,7 +261,7 @@ class KubernetesServiceConnection(SerializerModel, OutpostServiceConnection): | ||||
|         return "ak-service-connection-kubernetes-form" | ||||
|  | ||||
|  | ||||
| class Outpost(SerializerModel, ManagedModel): | ||||
| class Outpost(ScheduledModel, SerializerModel, ManagedModel): | ||||
|     """Outpost instance which manages a service user and token""" | ||||
|  | ||||
|     uuid = models.UUIDField(default=uuid4, editable=False, primary_key=True) | ||||
| @ -298,6 +315,21 @@ class Outpost(SerializerModel, ManagedModel): | ||||
|         """Username for service user""" | ||||
|         return f"ak-outpost-{self.uuid.hex}" | ||||
|  | ||||
|     @property | ||||
|     def schedule_specs(self) -> list[ScheduleSpec]: | ||||
|         from authentik.outposts.tasks import outpost_controller | ||||
|  | ||||
|         return [ | ||||
|             ScheduleSpec( | ||||
|                 actor=outpost_controller, | ||||
|                 uid=self.pk, | ||||
|                 args=(self.pk,), | ||||
|                 kwargs={"action": "up", "from_cache": False}, | ||||
|                 crontab=f"{fqdn_rand('outpost_controller')} */4 * * *", | ||||
|                 send_on_save=True, | ||||
|             ), | ||||
|         ] | ||||
|  | ||||
|     def build_user_permissions(self, user: User): | ||||
|         """Create per-object and global permissions for outpost service-account""" | ||||
|         # To ensure the user only has the correct permissions, we delete all of them and re-add | ||||
|  | ||||
| @ -1,28 +0,0 @@ | ||||
| """Outposts Settings""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
|  | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "outposts_controller": { | ||||
|         "task": "authentik.outposts.tasks.outpost_controller_all", | ||||
|         "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
|     "outposts_service_connection_check": { | ||||
|         "task": "authentik.outposts.tasks.outpost_service_connection_monitor", | ||||
|         "schedule": crontab(minute="3-59/15"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
|     "outpost_token_ensurer": { | ||||
|         "task": "authentik.outposts.tasks.outpost_token_ensurer", | ||||
|         "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
|     "outpost_connection_discovery": { | ||||
|         "task": "authentik.outposts.tasks.outpost_connection_discovery", | ||||
|         "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     }, | ||||
| } | ||||
| @ -1,7 +1,6 @@ | ||||
| """authentik outpost signals""" | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||
| from django.dispatch import receiver | ||||
| from structlog.stdlib import get_logger | ||||
| @ -9,27 +8,19 @@ from structlog.stdlib import get_logger | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import AuthenticatedSession, Provider | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||
| from authentik.outposts.models import Outpost, OutpostModel, OutpostServiceConnection | ||||
| from authentik.outposts.tasks import ( | ||||
|     CACHE_KEY_OUTPOST_DOWN, | ||||
|     outpost_controller, | ||||
|     outpost_post_save, | ||||
|     outpost_send_update, | ||||
|     outpost_session_end, | ||||
| ) | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| UPDATE_TRIGGERING_MODELS = ( | ||||
|     Outpost, | ||||
|     OutpostServiceConnection, | ||||
|     Provider, | ||||
|     CertificateKeyPair, | ||||
|     Brand, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @receiver(pre_save, sender=Outpost) | ||||
| def pre_save_outpost(sender, instance: Outpost, **_): | ||||
| def outpost_pre_save(sender, instance: Outpost, **_): | ||||
|     """Pre-save checks for an outpost, if the name or config.kubernetes_namespace changes, | ||||
|     we call down and then wait for the up after save""" | ||||
|     old_instances = Outpost.objects.filter(pk=instance.pk) | ||||
| @ -44,43 +35,89 @@ def pre_save_outpost(sender, instance: Outpost, **_): | ||||
|     if bool(dirty): | ||||
|         LOGGER.info("Outpost needs re-deployment due to changes", instance=instance) | ||||
|         cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, old_instance) | ||||
|         outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||
|         outpost_controller.send_with_options( | ||||
|             args=(instance.pk.hex,), | ||||
|             kwargs={"action": "down", "from_cache": True}, | ||||
|             rel_obj=instance, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @receiver(m2m_changed, sender=Outpost.providers.through) | ||||
| def m2m_changed_update(sender, instance: Model, action: str, **_): | ||||
| def outpost_m2m_changed(sender, instance: Outpost | Provider, action: str, **_): | ||||
|     """Update outpost on m2m change, when providers are added or removed""" | ||||
|     if action in ["post_add", "post_remove", "post_clear"]: | ||||
|         outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) | ||||
|     if action not in ["post_add", "post_remove", "post_clear"]: | ||||
|         return | ||||
|     if isinstance(instance, Outpost): | ||||
|         outpost_controller.send_with_options( | ||||
|             args=(instance.pk,), | ||||
|             rel_obj=instance.service_connection, | ||||
|         ) | ||||
|         outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||
|     elif isinstance(instance, OutpostModel): | ||||
|         for outpost in instance.outpost_set.all(): | ||||
|             outpost_controller.send_with_options( | ||||
|                 args=(outpost.pk,), | ||||
|                 rel_obj=outpost.service_connection, | ||||
|             ) | ||||
|             outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||
|  | ||||
|  | ||||
| @receiver(post_save) | ||||
| def post_save_update(sender, instance: Model, created: bool, **_): | ||||
|     """If an Outpost is saved, Ensure that token is created/updated | ||||
|  | ||||
|     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, | ||||
|     we send a message down the relevant OutpostModels WS connection to trigger an update""" | ||||
|     if instance.__module__ == "django.db.migrations.recorder": | ||||
|         return | ||||
|     if instance.__module__ == "__fake__": | ||||
|         return | ||||
|     if not isinstance(instance, UPDATE_TRIGGERING_MODELS): | ||||
|         return | ||||
|     if isinstance(instance, Outpost) and created: | ||||
| @receiver(post_save, sender=Outpost) | ||||
| def outpost_post_save(sender, instance: Outpost, created: bool, **_): | ||||
|     if created: | ||||
|         LOGGER.info("New outpost saved, ensuring initial token and user are created") | ||||
|         _ = instance.token | ||||
|     outpost_post_save.delay(class_to_path(instance.__class__), instance.pk) | ||||
|     outpost_controller.send_with_options(args=(instance.pk,), rel_obj=instance.service_connection) | ||||
|     outpost_send_update.send_with_options(args=(instance.pk,), rel_obj=instance) | ||||
|  | ||||
|  | ||||
| def outpost_related_post_save(sender, instance: OutpostServiceConnection | OutpostModel, **_): | ||||
|     for outpost in instance.outpost_set.all(): | ||||
|         outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||
|  | ||||
|  | ||||
| post_save.connect(outpost_related_post_save, sender=OutpostServiceConnection, weak=False) | ||||
| for subclass in OutpostModel.__subclasses__(): | ||||
|     post_save.connect(outpost_related_post_save, sender=subclass, weak=False) | ||||
|  | ||||
|  | ||||
| def outpost_reverse_related_post_save(sender, instance: CertificateKeyPair | Brand, **_): | ||||
|     for field in instance._meta.get_fields(): | ||||
|         # Each field is checked for a `related_model` attribute (set when ForeignKeys or M2Ms | ||||
|         # are used) and for a value | ||||
|         if not hasattr(field, "related_model"): | ||||
|             continue | ||||
|         if not field.related_model: | ||||
|             continue | ||||
|         if not issubclass(field.related_model, OutpostModel): | ||||
|             continue | ||||
|  | ||||
|         field_name = f"{field.name}_set" | ||||
|         if not hasattr(instance, field_name): | ||||
|             continue | ||||
|  | ||||
|         LOGGER.debug("triggering outpost update from field", field=field.name) | ||||
|         # Because the Outpost Model has an M2M to Provider, | ||||
|         # we have to iterate over the entire QS | ||||
|         for reverse in getattr(instance, field_name).all(): | ||||
|             if isinstance(reverse, OutpostModel): | ||||
|                 for outpost in reverse.outpost_set.all(): | ||||
|                     outpost_send_update.send_with_options(args=(outpost.pk,), rel_obj=outpost) | ||||
|  | ||||
|  | ||||
| post_save.connect(outpost_reverse_related_post_save, sender=Brand, weak=False) | ||||
| post_save.connect(outpost_reverse_related_post_save, sender=CertificateKeyPair, weak=False) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=Outpost) | ||||
| def pre_delete_cleanup(sender, instance: Outpost, **_): | ||||
| def outpost_pre_delete_cleanup(sender, instance: Outpost, **_): | ||||
|     """Ensure that Outpost's user is deleted (which will delete the token through cascade)""" | ||||
|     instance.user.delete() | ||||
|     cache.set(CACHE_KEY_OUTPOST_DOWN % instance.pk.hex, instance) | ||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||
|     outpost_controller.send(instance.pk.hex, action="down", from_cache=True) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
| def outpost_logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|     outpost_session_end.delay(instance.session.session_key) | ||||
|     outpost_session_end.send(instance.session.session_key) | ||||
|  | ||||
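These signal handlers also show the dispatch migration in miniature: celery's `.delay(*args, **kwargs)` becomes dramatiq's `.send(...)`, and `apply_async`-style options move to `.send_with_options(...)`; the `rel_obj` option seen in this diff comes from the django_dramatiq_postgres middleware, not core dramatiq. A runnable sketch against a stub broker:

```python
# Minimal sketch of the celery -> dramatiq dispatch change; the StubBroker is
# only for illustration, authentik wires its own postgres-backed broker.
import dramatiq
from dramatiq.brokers.stub import StubBroker

dramatiq.set_broker(StubBroker())


@dramatiq.actor
def controller(outpost_pk: str, action: str = "up", from_cache: bool = False):
    print(outpost_pk, action, from_cache)


# celery: controller.delay("abc123", action="down", from_cache=True)
controller.send("abc123", action="down", from_cache=True)

# options now ride on the message instead of apply_async(**options);
# rel_obj is a middleware-defined option and is deliberately omitted here.
controller.send_with_options(
    args=("abc123",),
    kwargs={"action": "down", "from_cache": True},
)
```
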
| @ -10,19 +10,17 @@ from urllib.parse import urlparse | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.core.cache import cache | ||||
| from django.db import DatabaseError, InternalError, ProgrammingError | ||||
| from django.db.models.base import Model | ||||
| from django.utils.text import slugify | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_dramatiq_postgres.middleware import CurrentTask | ||||
| from docker.constants import DEFAULT_UNIX_SOCKET | ||||
| from dramatiq.actor import actor | ||||
| from kubernetes.config.incluster_config import SERVICE_TOKEN_FILENAME | ||||
| from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION | ||||
| from structlog.stdlib import get_logger | ||||
| from yaml import safe_load | ||||
|  | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.utils.reflection import path_to_class | ||||
| from authentik.outposts.consumer import OUTPOST_GROUP | ||||
| from authentik.outposts.controllers.base import BaseController, ControllerException | ||||
| from authentik.outposts.controllers.docker import DockerClient | ||||
| @ -31,7 +29,6 @@ from authentik.outposts.models import ( | ||||
|     DockerServiceConnection, | ||||
|     KubernetesServiceConnection, | ||||
|     Outpost, | ||||
|     OutpostModel, | ||||
|     OutpostServiceConnection, | ||||
|     OutpostType, | ||||
|     ServiceConnectionInvalid, | ||||
| @ -44,7 +41,7 @@ from authentik.providers.rac.controllers.docker import RACDockerController | ||||
| from authentik.providers.rac.controllers.kubernetes import RACKubernetesController | ||||
| from authentik.providers.radius.controllers.docker import RadiusDockerController | ||||
| from authentik.providers.radius.controllers.kubernetes import RadiusKubernetesController | ||||
| from authentik.root.celery import CELERY_APP | ||||
| from authentik.tasks.models import Task | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| CACHE_KEY_OUTPOST_DOWN = "goauthentik.io/outposts/teardown/%s" | ||||
| @ -83,8 +80,8 @@ def controller_for_outpost(outpost: Outpost) -> type[BaseController] | None: | ||||
|     return None | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def outpost_service_connection_state(connection_pk: Any): | ||||
| @actor(description=_("Update cached state of service connection.")) | ||||
| def outpost_service_connection_monitor(connection_pk: Any): | ||||
|     """Update cached state of a service connection""" | ||||
|     connection: OutpostServiceConnection = ( | ||||
|         OutpostServiceConnection.objects.filter(pk=connection_pk).select_subclasses().first() | ||||
| @ -108,37 +105,11 @@ def outpost_service_connection_state(connection_pk: Any): | ||||
|     cache.set(connection.state_key, state, timeout=None) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     bind=True, | ||||
|     base=SystemTask, | ||||
|     throws=(DatabaseError, ProgrammingError, InternalError), | ||||
| ) | ||||
| @prefill_task | ||||
| def outpost_service_connection_monitor(self: SystemTask): | ||||
|     """Regularly check the state of Outpost Service Connections""" | ||||
|     connections = OutpostServiceConnection.objects.all() | ||||
|     for connection in connections.iterator(): | ||||
|         outpost_service_connection_state.delay(connection.pk) | ||||
|     self.set_status( | ||||
|         TaskStatus.SUCCESSFUL, | ||||
|         f"Successfully updated {len(connections)} connections.", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     throws=(DatabaseError, ProgrammingError, InternalError), | ||||
| ) | ||||
| def outpost_controller_all(): | ||||
|     """Launch Controller for all Outposts which support it""" | ||||
|     for outpost in Outpost.objects.exclude(service_connection=None): | ||||
|         outpost_controller.delay(outpost.pk.hex, "up", from_cache=False) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| def outpost_controller( | ||||
|     self: SystemTask, outpost_pk: str, action: str = "up", from_cache: bool = False | ||||
| ): | ||||
| @actor(description=_("Create/update/monitor/delete the deployment of an Outpost.")) | ||||
| def outpost_controller(outpost_pk: str, action: str = "up", from_cache: bool = False): | ||||
|     """Create/update/monitor/delete the deployment of an Outpost""" | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     self.set_uid(outpost_pk) | ||||
|     logs = [] | ||||
|     if from_cache: | ||||
|         outpost: Outpost = cache.get(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||
| @ -159,125 +130,65 @@ def outpost_controller( | ||||
|             logs = getattr(controller, f"{action}_with_logs")() | ||||
|             LOGGER.debug("-----------------Outpost Controller logs end-------------------") | ||||
|     except (ControllerException, ServiceConnectionInvalid) as exc: | ||||
|         self.set_error(exc) | ||||
|         self.error(exc) | ||||
|     else: | ||||
|         if from_cache: | ||||
|             cache.delete(CACHE_KEY_OUTPOST_DOWN % outpost_pk) | ||||
|         self.set_status(TaskStatus.SUCCESSFUL, *logs) | ||||
|         self.logs(logs) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def outpost_token_ensurer(self: SystemTask): | ||||
|     """Periodically ensure that all Outposts have valid Service Accounts | ||||
|     and Tokens""" | ||||
| @actor(description=_("Ensure that all Outposts have valid Service Accounts and Tokens.")) | ||||
| def outpost_token_ensurer(): | ||||
|     """ | ||||
|     Periodically ensure that all Outposts have valid Service Accounts and Tokens | ||||
|     """ | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     all_outposts = Outpost.objects.all() | ||||
|     for outpost in all_outposts: | ||||
|         _ = outpost.token | ||||
|         outpost.build_user_permissions(outpost.user) | ||||
|     self.set_status( | ||||
|         TaskStatus.SUCCESSFUL, | ||||
|         f"Successfully checked {len(all_outposts)} Outposts.", | ||||
|     ) | ||||
|     self.info(f"Successfully checked {len(all_outposts)} Outposts.") | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def outpost_post_save(model_class: str, model_pk: Any): | ||||
|     """If an Outpost is saved, Ensure that token is created/updated | ||||
|  | ||||
|     If an OutpostModel, or a model that is somehow connected to an OutpostModel is saved, | ||||
|     we send a message down the relevant OutpostModels WS connection to trigger an update""" | ||||
|     model: Model = path_to_class(model_class) | ||||
|     try: | ||||
|         instance = model.objects.get(pk=model_pk) | ||||
|     except model.DoesNotExist: | ||||
|         LOGGER.warning("Model does not exist", model=model, pk=model_pk) | ||||
| @actor(description=_("Send update to outpost")) | ||||
| def outpost_send_update(pk: Any): | ||||
|     """Update outpost instance""" | ||||
|     outpost = Outpost.objects.filter(pk=pk).first() | ||||
|     if not outpost: | ||||
|         return | ||||
|  | ||||
|     if isinstance(instance, Outpost): | ||||
|         LOGGER.debug("Trigger reconcile for outpost", instance=instance) | ||||
|         outpost_controller.delay(str(instance.pk)) | ||||
|  | ||||
|     if isinstance(instance, OutpostModel | Outpost): | ||||
|         LOGGER.debug("triggering outpost update from outpostmodel/outpost", instance=instance) | ||||
|         outpost_send_update(instance) | ||||
|  | ||||
|     if isinstance(instance, OutpostServiceConnection): | ||||
|         LOGGER.debug("triggering ServiceConnection state update", instance=instance) | ||||
|         outpost_service_connection_state.delay(str(instance.pk)) | ||||
|  | ||||
|     for field in instance._meta.get_fields(): | ||||
|         # Each field is checked if it has a `related_model` attribute (when ForeginKeys or M2Ms) | ||||
|         # are used, and if it has a value | ||||
|         if not hasattr(field, "related_model"): | ||||
|             continue | ||||
|         if not field.related_model: | ||||
|             continue | ||||
|         if not issubclass(field.related_model, OutpostModel): | ||||
|             continue | ||||
|  | ||||
|         field_name = f"{field.name}_set" | ||||
|         if not hasattr(instance, field_name): | ||||
|             continue | ||||
|  | ||||
|         LOGGER.debug("triggering outpost update from field", field=field.name) | ||||
|         # Because the Outpost Model has an M2M to Provider, | ||||
|         # we have to iterate over the entire QS | ||||
|         for reverse in getattr(instance, field_name).all(): | ||||
|             outpost_send_update(reverse) | ||||
|  | ||||
|  | ||||
| def outpost_send_update(model_instance: Model): | ||||
|     """Send outpost update to all registered outposts, regardless to which authentik | ||||
|     instance they are connected""" | ||||
|     channel_layer = get_channel_layer() | ||||
|     if isinstance(model_instance, OutpostModel): | ||||
|         for outpost in model_instance.outpost_set.all(): | ||||
|             _outpost_single_update(outpost, channel_layer) | ||||
|     elif isinstance(model_instance, Outpost): | ||||
|         _outpost_single_update(model_instance, channel_layer) | ||||
|  | ||||
|  | ||||
| def _outpost_single_update(outpost: Outpost, layer=None): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     # Ensure token again, because this function is called when anything related to an | ||||
|     # OutpostModel is saved, so we can be sure permissions are right | ||||
|     _ = outpost.token | ||||
|     outpost.build_user_permissions(outpost.user) | ||||
|     if not layer:  # pragma: no cover | ||||
|         layer = get_channel_layer() | ||||
|     layer = get_channel_layer() | ||||
|     group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} | ||||
|     LOGGER.debug("sending update", channel=group, outpost=outpost) | ||||
|     async_to_sync(layer.group_send)(group, {"type": "event.update"}) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task( | ||||
|     base=SystemTask, | ||||
|     bind=True, | ||||
| ) | ||||
| def outpost_connection_discovery(self: SystemTask): | ||||
| @actor(description=_("Checks the local environment and create Service connections.")) | ||||
| def outpost_connection_discovery(): | ||||
|     """Checks the local environment and create Service connections.""" | ||||
|     messages = [] | ||||
|     self: Task = CurrentTask.get_task() | ||||
|     if not CONFIG.get_bool("outposts.discover"): | ||||
|         messages.append("Outpost integration discovery is disabled") | ||||
|         self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|         self.info("Outpost integration discovery is disabled") | ||||
|         return | ||||
|     # Explicitly check against token filename, as that's | ||||
|     # only present when the integration is enabled | ||||
|     if Path(SERVICE_TOKEN_FILENAME).exists(): | ||||
|         messages.append("Detected in-cluster Kubernetes Config") | ||||
|         self.info("Detected in-cluster Kubernetes Config") | ||||
|         if not KubernetesServiceConnection.objects.filter(local=True).exists(): | ||||
|             messages.append("Created Service Connection for in-cluster") | ||||
|             self.info("Created Service Connection for in-cluster") | ||||
|             KubernetesServiceConnection.objects.create( | ||||
|                 name="Local Kubernetes Cluster", local=True, kubeconfig={} | ||||
|             ) | ||||
|     # For development, check for the existence of a kubeconfig file | ||||
|     kubeconfig_path = Path(KUBE_CONFIG_DEFAULT_LOCATION).expanduser() | ||||
|     if kubeconfig_path.exists(): | ||||
|         messages.append("Detected kubeconfig") | ||||
|         self.info("Detected kubeconfig") | ||||
|         kubeconfig_local_name = f"k8s-{gethostname()}" | ||||
|         if not KubernetesServiceConnection.objects.filter(name=kubeconfig_local_name).exists(): | ||||
|             messages.append("Creating kubeconfig Service Connection") | ||||
|             self.info("Creating kubeconfig Service Connection") | ||||
|             with kubeconfig_path.open("r", encoding="utf8") as _kubeconfig: | ||||
|                 KubernetesServiceConnection.objects.create( | ||||
|                     name=kubeconfig_local_name, | ||||
| @ -286,20 +197,18 @@ def outpost_connection_discovery(self: SystemTask): | ||||
|     unix_socket_path = urlparse(DEFAULT_UNIX_SOCKET).path | ||||
|     socket = Path(unix_socket_path) | ||||
|     if socket.exists() and access(socket, R_OK): | ||||
|         messages.append("Detected local docker socket") | ||||
|         self.info("Detected local docker socket") | ||||
|         if len(DockerServiceConnection.objects.filter(local=True)) == 0: | ||||
|             messages.append("Created Service Connection for docker") | ||||
|             self.info("Created Service Connection for docker") | ||||
|             DockerServiceConnection.objects.create( | ||||
|                 name="Local Docker connection", | ||||
|                 local=True, | ||||
|                 url=unix_socket_path, | ||||
|             ) | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| @actor(description=_("Terminate session on all outposts.")) | ||||
| def outpost_session_end(session_id: str): | ||||
|     """Update outpost instances connected to a single outpost""" | ||||
|     layer = get_channel_layer() | ||||
|     hashed_session_id = hash_session_key(session_id) | ||||
|     for outpost in Outpost.objects.all(): | ||||
|  | ||||
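Throughout these tasks the celery `bind=True`/`self: SystemTask` pattern is replaced by fetching the tracking row from middleware. A sketch of the pattern, with the `Task` logging methods (`set_uid`, `info`, `warning`, `error`, `logs`) assumed from their use in this diff rather than from documentation:

```python
# Sketch of the CurrentTask pattern used by the actors above.
from django.utils.translation import gettext_lazy as _
from django_dramatiq_postgres.middleware import CurrentTask
from dramatiq.actor import actor

from authentik.tasks.models import Task


@actor(description=_("Illustrative task reporting progress."))
def example_task(obj_pk: str):
    self: Task = CurrentTask.get_task()  # row tracking the current message
    self.set_uid(obj_pk)                 # shown in the task list UI
    self.info(f"processing {obj_pk}")    # structured log entry on the task
```
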
| @ -37,6 +37,7 @@ class OutpostTests(TestCase): | ||||
|  | ||||
|         # We add a provider, user should only have access to outpost and provider | ||||
|         outpost.providers.add(provider) | ||||
|         provider.refresh_from_db() | ||||
|         permissions = UserObjectPermission.objects.filter(user=outpost.user).order_by( | ||||
|             "content_type__model" | ||||
|         ) | ||||
|  | ||||
| @ -15,6 +15,7 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | ||||
|     def proxy_set_defaults(self): | ||||
|         from authentik.providers.proxy.models import ProxyProvider | ||||
|  | ||||
|         # TODO: figure out if this can be in pre_save + post_save signals | ||||
|         for provider in ProxyProvider.objects.all(): | ||||
|             provider.set_oauth_defaults() | ||||
|             provider.save() | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff
	