Compare commits
	
		
			153 Commits
		
	
	
		
			smusali/ev
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 1c03cfa906 | |||
| e2dbab5bca | |||
| 3a6c42fefb | |||
| 6bb180f94e | |||
| 03dea17519 | |||
| 49d83f11bd | |||
| 5f0af81e4d | |||
| 63591e1710 | |||
| 6503a7b048 | |||
| 7e244e0679 | |||
| c1998bf3f2 | |||
| 83372618a8 | |||
| 89a876e141 | |||
| 26d6e8bc5c | |||
| d9dc373170 | |||
| 4ec37c5239 | |||
| a9cfa6fe35 | |||
| 5ac5084149 | |||
| eda38a30b1 | |||
| 9b84bf7174 | |||
| f74549be6d | |||
| 76f4d7fb0a | |||
| d1cf1dd083 | |||
| 2835fbd390 | |||
| 76ad2c8925 | |||
| 2270629fdc | |||
| 43a629efc1 | |||
| 4044e52403 | |||
| aa7c846467 | |||
| 8ab7f4073b | |||
| a05856c2ef | |||
| 9e9154e04a | |||
| 32549066c0 | |||
| 5ed3e879a2 | |||
| 4e4923ad0e | |||
| 0302d147e9 | |||
| 8256f1897d | |||
| 16d321835d | |||
| f34612efe6 | |||
| e82f147130 | |||
| 0ea6ad8eea | |||
| f731443220 | |||
| b70a66cde5 | |||
| b733dbbcb0 | |||
| e34d4c0669 | |||
| 310983a4d0 | |||
| 47b0fc86f7 | |||
| b6e961b1f3 | |||
| 874d7ff320 | |||
| e4a5bc9df6 | |||
| 318e0cf9f8 | |||
| bd0815d894 | |||
| af35ecfe66 | |||
| 0c05cd64bb | |||
| cb80b76490 | |||
| 061d4bc758 | |||
| 8ff27f69e1 | |||
| 045cd98276 | |||
| b520843984 | |||
| 92216e4ea8 | |||
| babaeb2d0c | |||
| 52b8f24b75 | |||
| 464addfc8d | |||
| 8df73c2f6f | |||
| 9ab3971e63 | |||
| 09888cb89f | |||
| 2abcc9ce8f | |||
| 5b0e92f034 | |||
| a3bfb3d25c | |||
| 2c1df6702c | |||
| b999e23d27 | |||
| e0db9f3ea1 | |||
| dcc3ca664a | |||
| 7d37e3f668 | |||
| e48f6bbec4 | |||
| d27caaabc3 | |||
| 0dee706a87 | |||
| 7d527beea8 | |||
| 4733778460 | |||
| c048f4a356 | |||
| 65e245c003 | |||
| 600d59ff58 | |||
| 703628f354 | |||
| 693de081ef | |||
| f367249bab | |||
| 2841db082c | |||
| ce24f974aa | |||
| 1f93e6fd3f | |||
| 7dfde9029f | |||
| f5d62b828b | |||
| 703eb682b7 | |||
| 5cae3192b1 | |||
| 83e143032d | |||
| e0e7cc24da | |||
| 8bc746d577 | |||
| a84f403e79 | |||
| e4f4482d2a | |||
| 844b4e96cd | |||
| f3b4e03243 | |||
| 4f5e2a438e | |||
| 32c980e29e | |||
| bd29392825 | |||
| 9756432876 | |||
| 8b2d1a9b21 | |||
| adbd97323c | |||
| 77a8b2d751 | |||
| 08c850938b | |||
| 7db598c04e | |||
| 1ef224f5fd | |||
| b01c48698d | |||
| 1546fa276a | |||
| f50bd74b46 | |||
| 414a5c36c8 | |||
| c4455b6915 | |||
| 9013caeab4 | |||
| 40a1e5a9b2 | |||
| 4dadcc1dfd | |||
| 0b8678f7ee | |||
| aa8dc94a97 | |||
| 20996e994e | |||
| db17f04830 | |||
| b99fca62d8 | |||
| 8818ce3306 | |||
| 25d3f2e06e | |||
| 1537682026 | |||
| ebd05be2c4 | |||
| c90792d876 | |||
| b92630804f | |||
| 1afd5ef95a | |||
| e5cc2c6d98 | |||
| 84fdd4d737 | |||
| 5fe2772567 | |||
| d2f9b66424 | |||
| c9b39f2eba | |||
| 2ecc2119fc | |||
| 49b7ebdc53 | |||
| 70f72c524d | |||
| 87e0ac743a | |||
| b5b8b0e9cd | |||
| d10b358767 | |||
| 0887fa8fde | |||
| 799dd48861 | |||
| 919d1f349f | |||
| a36b6e8315 | |||
| 69f9dfc9f6 | |||
| 27efe68f1c | |||
| e9223618ba | |||
| e9672a5285 | |||
| 7d724d9931 | |||
| edcc6b2031 | |||
| a71948c9b7 | |||
| a395e347df | |||
| f4b336a974 | 
| @ -1,12 +1,20 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2023.10.7 | current_version = 2024.2.4 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
| serialize = {major}.{minor}.{patch} | serialize =  | ||||||
|  | 	{major}.{minor}.{patch}-{rc_t}{rc_n} | ||||||
|  | 	{major}.{minor}.{patch} | ||||||
| message = release: {new_version} | message = release: {new_version} | ||||||
| tag_name = version/{new_version} | tag_name = version/{new_version} | ||||||
|  |  | ||||||
|  | [bumpversion:part:rc_t] | ||||||
|  | values =  | ||||||
|  | 	rc | ||||||
|  | 	final | ||||||
|  | optional_value = final | ||||||
|  |  | ||||||
| [bumpversion:file:pyproject.toml] | [bumpversion:file:pyproject.toml] | ||||||
|  |  | ||||||
| [bumpversion:file:docker-compose.yml] | [bumpversion:file:docker-compose.yml] | ||||||
|  | |||||||
| @ -9,9 +9,6 @@ inputs: | |||||||
| runs: | runs: | ||||||
|   using: "composite" |   using: "composite" | ||||||
|   steps: |   steps: | ||||||
|     - name: Generate config |  | ||||||
|       id: ev |  | ||||||
|       uses: ./.github/actions/docker-push-variables |  | ||||||
|     - name: Find Comment |     - name: Find Comment | ||||||
|       uses: peter-evans/find-comment@v2 |       uses: peter-evans/find-comment@v2 | ||||||
|       id: fc |       id: fc | ||||||
|  | |||||||
							
								
								
									
										73
									
								
								.github/actions/docker-push-variables/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										73
									
								
								.github/actions/docker-push-variables/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,64 +1,47 @@ | |||||||
|  | --- | ||||||
| name: "Prepare docker environment variables" | name: "Prepare docker environment variables" | ||||||
| description: "Prepare docker environment variables" | description: "Prepare docker environment variables" | ||||||
|  |  | ||||||
|  | inputs: | ||||||
|  |   image-name: | ||||||
|  |     required: true | ||||||
|  |     description: "Docker image prefix" | ||||||
|  |   image-arch: | ||||||
|  |     required: false | ||||||
|  |     description: "Docker image arch" | ||||||
|  |  | ||||||
| outputs: | outputs: | ||||||
|   shouldBuild: |   shouldBuild: | ||||||
|     description: "Whether to build image or not" |     description: "Whether to build image or not" | ||||||
|     value: ${{ steps.ev.outputs.shouldBuild }} |     value: ${{ steps.ev.outputs.shouldBuild }} | ||||||
|   branchName: |  | ||||||
|     description: "Branch name" |  | ||||||
|     value: ${{ steps.ev.outputs.branchName }} |  | ||||||
|   branchNameContainer: |  | ||||||
|     description: "Branch name (for containers)" |  | ||||||
|     value: ${{ steps.ev.outputs.branchNameContainer }} |  | ||||||
|   timestamp: |  | ||||||
|     description: "Timestamp" |  | ||||||
|     value: ${{ steps.ev.outputs.timestamp }} |  | ||||||
|   sha: |   sha: | ||||||
|     description: "sha" |     description: "sha" | ||||||
|     value: ${{ steps.ev.outputs.sha }} |     value: ${{ steps.ev.outputs.sha }} | ||||||
|   shortHash: |  | ||||||
|     description: "shortHash" |  | ||||||
|     value: ${{ steps.ev.outputs.shortHash }} |  | ||||||
|   version: |   version: | ||||||
|     description: "version" |     description: "Version" | ||||||
|     value: ${{ steps.ev.outputs.version }} |     value: ${{ steps.ev.outputs.version }} | ||||||
|   versionFamily: |   prerelease: | ||||||
|     description: "versionFamily" |     description: "Prerelease" | ||||||
|     value: ${{ steps.ev.outputs.versionFamily }} |     value: ${{ steps.ev.outputs.prerelease }} | ||||||
|  |  | ||||||
|  |   imageTags: | ||||||
|  |     description: "Docker image tags" | ||||||
|  |     value: ${{ steps.ev.outputs.imageTags }} | ||||||
|  |   imageMainTag: | ||||||
|  |     description: "Docker image main tag" | ||||||
|  |     value: ${{ steps.ev.outputs.imageMainTag }} | ||||||
|  |  | ||||||
| runs: | runs: | ||||||
|   using: "composite" |   using: "composite" | ||||||
|   steps: |   steps: | ||||||
|     - name: Generate config |     - name: Generate config | ||||||
|       id: ev |       id: ev | ||||||
|       shell: python |       shell: bash | ||||||
|  |       env: | ||||||
|  |         IMAGE_NAME: ${{ inputs.image-name }} | ||||||
|  |         IMAGE_ARCH: ${{ inputs.image-arch }} | ||||||
|  |         PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }} | ||||||
|       run: | |       run: | | ||||||
|         """Helper script to get the actual branch name, docker safe""" |         python3 ${{ github.action_path }}/push_vars.py | ||||||
|         import configparser |  | ||||||
|         import os |  | ||||||
|         from time import time |  | ||||||
|  |  | ||||||
|         parser = configparser.ConfigParser() |  | ||||||
|         parser.read(".bumpversion.cfg") |  | ||||||
|  |  | ||||||
|         branch_name = os.environ["GITHUB_REF"] |  | ||||||
|         if os.environ.get("GITHUB_HEAD_REF", "") != "": |  | ||||||
|             branch_name = os.environ["GITHUB_HEAD_REF"] |  | ||||||
|  |  | ||||||
|         should_build = str(os.environ.get("DOCKER_USERNAME", "") != "").lower() |  | ||||||
|         version = parser.get("bumpversion", "current_version") |  | ||||||
|         version_family = ".".join(version.split(".")[:-1]) |  | ||||||
|         safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") |  | ||||||
|  |  | ||||||
|         sha = os.environ["GITHUB_SHA"] if not "${{ github.event.pull_request.head.sha }}" else "${{ github.event.pull_request.head.sha }}" |  | ||||||
|  |  | ||||||
|         with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: |  | ||||||
|             print("branchName=%s" % branch_name, file=_output) |  | ||||||
|             print("branchNameContainer=%s" % safe_branch_name, file=_output) |  | ||||||
|             print("timestamp=%s" % int(time()), file=_output) |  | ||||||
|             print("sha=%s" % sha, file=_output) |  | ||||||
|             print("shortHash=%s" % sha[:7], file=_output) |  | ||||||
|             print("shouldBuild=%s" % should_build, file=_output) |  | ||||||
|             print("version=%s" % version, file=_output) |  | ||||||
|             print("versionFamily=%s" % version_family, file=_output) |  | ||||||
|  | |||||||
							
								
								
									
										62
									
								
								.github/actions/docker-push-variables/push_vars.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								.github/actions/docker-push-variables/push_vars.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | |||||||
|  | """Helper script to get the actual branch name, docker safe""" | ||||||
|  |  | ||||||
|  | import configparser | ||||||
|  | import os | ||||||
|  | from time import time | ||||||
|  |  | ||||||
|  | parser = configparser.ConfigParser() | ||||||
|  | parser.read(".bumpversion.cfg") | ||||||
|  |  | ||||||
|  | should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | ||||||
|  |  | ||||||
|  | branch_name = os.environ["GITHUB_REF"] | ||||||
|  | if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||||
|  |     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||||
|  | safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | ||||||
|  |  | ||||||
|  | image_names = os.getenv("IMAGE_NAME").split(",") | ||||||
|  | image_arch = os.getenv("IMAGE_ARCH") or None | ||||||
|  |  | ||||||
|  | is_pull_request = bool(os.getenv("PR_HEAD_SHA")) | ||||||
|  | is_release = "dev" not in image_names[0] | ||||||
|  |  | ||||||
|  | sha = os.environ["GITHUB_SHA"] if not is_pull_request else os.getenv("PR_HEAD_SHA") | ||||||
|  |  | ||||||
|  | # 2042.1.0 or 2042.1.0-rc1 | ||||||
|  | version = parser.get("bumpversion", "current_version") | ||||||
|  | # 2042.1 | ||||||
|  | version_family = ".".join(version.split("-", 1)[0].split(".")[:-1]) | ||||||
|  | prerelease = "-" in version | ||||||
|  |  | ||||||
|  | image_tags = [] | ||||||
|  | if is_release: | ||||||
|  |     for name in image_names: | ||||||
|  |         image_tags += [ | ||||||
|  |             f"{name}:{version}", | ||||||
|  |         ] | ||||||
|  |         if not prerelease: | ||||||
|  |             image_tags += [ | ||||||
|  |                 f"{name}:latest", | ||||||
|  |                 f"{name}:{version_family}", | ||||||
|  |             ] | ||||||
|  | else: | ||||||
|  |     suffix = "" | ||||||
|  |     if image_arch and image_arch != "amd64": | ||||||
|  |         suffix = f"-{image_arch}" | ||||||
|  |     for name in image_names: | ||||||
|  |         image_tags += [ | ||||||
|  |             f"{name}:gh-{sha}{suffix}",  # Used for ArgoCD and PR comments | ||||||
|  |             f"{name}:gh-{safe_branch_name}{suffix}",  # For convenience | ||||||
|  |             f"{name}:gh-{safe_branch_name}-{int(time())}-{sha[:7]}{suffix}",  # Use by FluxCD | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  | image_main_tag = image_tags[0] | ||||||
|  | image_tags_rendered = ",".join(image_tags) | ||||||
|  |  | ||||||
|  | with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||||
|  |     print("shouldBuild=%s" % should_build, file=_output) | ||||||
|  |     print("sha=%s" % sha, file=_output) | ||||||
|  |     print("version=%s" % version, file=_output) | ||||||
|  |     print("prerelease=%s" % prerelease, file=_output) | ||||||
|  |     print("imageTags=%s" % image_tags_rendered, file=_output) | ||||||
|  |     print("imageMainTag=%s" % image_main_tag, file=_output) | ||||||
							
								
								
									
										7
									
								
								.github/actions/docker-push-variables/test.sh
									
									
									
									
										vendored
									
									
										Executable file
									
								
							
							
						
						
									
										7
									
								
								.github/actions/docker-push-variables/test.sh
									
									
									
									
										vendored
									
									
										Executable file
									
								
							| @ -0,0 +1,7 @@ | |||||||
|  | #!/bin/bash -x | ||||||
|  | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||||||
|  | GITHUB_OUTPUT=/dev/stdout \ | ||||||
|  |     GITHUB_REF=ref \ | ||||||
|  |     GITHUB_SHA=sha \ | ||||||
|  |     IMAGE_NAME=ghcr.io/goauthentik/server,beryju/authentik \ | ||||||
|  |     python $SCRIPT_DIR/push_vars.py | ||||||
							
								
								
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							| @ -3,3 +3,4 @@ keypairs | |||||||
| hass | hass | ||||||
| warmup | warmup | ||||||
| ontext | ontext | ||||||
|  | singed | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							| @ -27,7 +27,6 @@ If an API change has been made | |||||||
| If changes to the frontend have been made | If changes to the frontend have been made | ||||||
|  |  | ||||||
| -   [ ] The code has been formatted (`make web`) | -   [ ] The code has been formatted (`make web`) | ||||||
| -   [ ] The translation files have been updated (`make i18n-extract`) |  | ||||||
|  |  | ||||||
| If applicable | If applicable | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										81
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										81
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | |||||||
|  | --- | ||||||
| name: authentik-ci-main | name: authentik-ci-main | ||||||
|  |  | ||||||
| on: | on: | ||||||
| @ -7,7 +8,7 @@ on: | |||||||
|       - next |       - next | ||||||
|       - version-* |       - version-* | ||||||
|     paths-ignore: |     paths-ignore: | ||||||
|       - website |       - website/** | ||||||
|   pull_request: |   pull_request: | ||||||
|     branches: |     branches: | ||||||
|       - main |       - main | ||||||
| @ -29,7 +30,7 @@ jobs: | |||||||
|           - codespell |           - codespell | ||||||
|           - isort |           - isort | ||||||
|           - pending-migrations |           - pending-migrations | ||||||
|           - pylint |           # - pylint | ||||||
|           - pyright |           - pyright | ||||||
|           - ruff |           - ruff | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
| @ -69,7 +70,7 @@ jobs: | |||||||
|           cp authentik/lib/default.yml local.env.yml |           cp authentik/lib/default.yml local.env.yml | ||||||
|           cp -R .github .. |           cp -R .github .. | ||||||
|           cp -R scripts .. |           cp -R scripts .. | ||||||
|           git checkout version/$(python -c "from authentik import __version__; print(__version__)") |           git checkout $(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1) | ||||||
|           rm -rf .github/ scripts/ |           rm -rf .github/ scripts/ | ||||||
|           mv ../.github ../scripts . |           mv ../.github ../scripts . | ||||||
|       - name: Setup authentik env (stable) |       - name: Setup authentik env (stable) | ||||||
| @ -134,7 +135,7 @@ jobs: | |||||||
|       - name: Setup authentik env |       - name: Setup authentik env | ||||||
|         uses: ./.github/actions/setup |         uses: ./.github/actions/setup | ||||||
|       - name: Create k8s Kind Cluster |       - name: Create k8s Kind Cluster | ||||||
|         uses: helm/kind-action@v1.8.0 |         uses: helm/kind-action@v1.9.0 | ||||||
|       - name: run integration |       - name: run integration | ||||||
|         run: | |         run: | | ||||||
|           poetry run coverage run manage.py test tests/integration |           poetry run coverage run manage.py test tests/integration | ||||||
| @ -206,6 +207,12 @@ jobs: | |||||||
|     steps: |     steps: | ||||||
|       - run: echo mark |       - run: echo mark | ||||||
|   build: |   build: | ||||||
|  |     strategy: | ||||||
|  |       fail-fast: false | ||||||
|  |       matrix: | ||||||
|  |         arch: | ||||||
|  |           - amd64 | ||||||
|  |           - arm64 | ||||||
|     needs: ci-core-mark |     needs: ci-core-mark | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     permissions: |     permissions: | ||||||
| @ -225,9 +232,12 @@ jobs: | |||||||
|         id: ev |         id: ev | ||||||
|         env: |         env: | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/dev-server | ||||||
|  |           image-arch: ${{ matrix.arch }} | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
|         uses: docker/login-action@v3 |  | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|  |         uses: docker/login-action@v3 | ||||||
|         with: |         with: | ||||||
|           registry: ghcr.io |           registry: ghcr.io | ||||||
|           username: ${{ github.repository_owner }} |           username: ${{ github.repository_owner }} | ||||||
| @ -241,69 +251,16 @@ jobs: | |||||||
|           secrets: | |           secrets: | | ||||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} |             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} |             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||||
|  |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} |           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|           tags: | |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }} |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }} |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }} |  | ||||||
|           build-args: | |           build-args: | | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||||
|             VERSION=${{ steps.ev.outputs.version }} |  | ||||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} |  | ||||||
|           cache-from: type=gha |  | ||||||
|           cache-to: type=gha,mode=max |  | ||||||
|   build-arm64: |  | ||||||
|     needs: ci-core-mark |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     permissions: |  | ||||||
|       # Needed to upload contianer images to ghcr.io |  | ||||||
|       packages: write |  | ||||||
|     timeout-minutes: 120 |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           ref: ${{ github.event.pull_request.head.sha }} |  | ||||||
|       - name: Set up QEMU |  | ||||||
|         uses: docker/setup-qemu-action@v3.0.0 |  | ||||||
|       - name: Set up Docker Buildx |  | ||||||
|         uses: docker/setup-buildx-action@v3 |  | ||||||
|       - name: prepare variables |  | ||||||
|         uses: ./.github/actions/docker-push-variables |  | ||||||
|         id: ev |  | ||||||
|         env: |  | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |  | ||||||
|       - name: Login to Container Registry |  | ||||||
|         uses: docker/login-action@v3 |  | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |  | ||||||
|         with: |  | ||||||
|           registry: ghcr.io |  | ||||||
|           username: ${{ github.repository_owner }} |  | ||||||
|           password: ${{ secrets.GITHUB_TOKEN }} |  | ||||||
|       - name: generate ts client |  | ||||||
|         run: make gen-client-ts |  | ||||||
|       - name: Build Docker Image |  | ||||||
|         uses: docker/build-push-action@v5 |  | ||||||
|         with: |  | ||||||
|           context: . |  | ||||||
|           secrets: | |  | ||||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} |  | ||||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} |  | ||||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} |  | ||||||
|           tags: | |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-arm64 |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }}-arm64 |  | ||||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}-arm64 |  | ||||||
|           build-args: | |  | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |  | ||||||
|             VERSION=${{ steps.ev.outputs.version }} |  | ||||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} |  | ||||||
|           platforms: linux/arm64 |  | ||||||
|           cache-from: type=gha |           cache-from: type=gha | ||||||
|           cache-to: type=gha,mode=max |           cache-to: type=gha,mode=max | ||||||
|  |           platforms: linux/${{ matrix.arch }} | ||||||
|   pr-comment: |   pr-comment: | ||||||
|     needs: |     needs: | ||||||
|       - build |       - build | ||||||
|       - build-arm64 |  | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     if: ${{ github.event_name == 'pull_request' }} |     if: ${{ github.event_name == 'pull_request' }} | ||||||
|     permissions: |     permissions: | ||||||
| @ -319,7 +276,9 @@ jobs: | |||||||
|         id: ev |         id: ev | ||||||
|         env: |         env: | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/dev-server | ||||||
|       - name: Comment on PR |       - name: Comment on PR | ||||||
|         uses: ./.github/actions/comment-pr-instructions |         uses: ./.github/actions/comment-pr-instructions | ||||||
|         with: |         with: | ||||||
|           tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }} |           tag: gh-${{ steps.ev.outputs.imageMainTag }} | ||||||
|  | |||||||
							
								
								
									
										15
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | |||||||
|  | --- | ||||||
| name: authentik-ci-outpost | name: authentik-ci-outpost | ||||||
|  |  | ||||||
| on: | on: | ||||||
| @ -28,7 +29,7 @@ jobs: | |||||||
|       - name: Generate API |       - name: Generate API | ||||||
|         run: make gen-client-go |         run: make gen-client-go | ||||||
|       - name: golangci-lint |       - name: golangci-lint | ||||||
|         uses: golangci/golangci-lint-action@v3 |         uses: golangci/golangci-lint-action@v4 | ||||||
|         with: |         with: | ||||||
|           version: v1.54.2 |           version: v1.54.2 | ||||||
|           args: --timeout 5000s --verbose |           args: --timeout 5000s --verbose | ||||||
| @ -83,9 +84,11 @@ jobs: | |||||||
|         id: ev |         id: ev | ||||||
|         env: |         env: | ||||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/dev-${{ matrix.type }} | ||||||
|       - name: Login to Container Registry |       - name: Login to Container Registry | ||||||
|         uses: docker/login-action@v3 |  | ||||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} |         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|  |         uses: docker/login-action@v3 | ||||||
|         with: |         with: | ||||||
|           registry: ghcr.io |           registry: ghcr.io | ||||||
|           username: ${{ github.repository_owner }} |           username: ${{ github.repository_owner }} | ||||||
| @ -95,15 +98,11 @@ jobs: | |||||||
|       - name: Build Docker Image |       - name: Build Docker Image | ||||||
|         uses: docker/build-push-action@v5 |         uses: docker/build-push-action@v5 | ||||||
|         with: |         with: | ||||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|           tags: | |  | ||||||
|             ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.branchNameContainer }} |  | ||||||
|             ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.sha }} |  | ||||||
|           file: ${{ matrix.type }}.Dockerfile |           file: ${{ matrix.type }}.Dockerfile | ||||||
|  |           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||||
|           build-args: | |           build-args: | | ||||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} |             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||||
|             VERSION=${{ steps.ev.outputs.version }} |  | ||||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} |  | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|           context: . |           context: . | ||||||
|           cache-from: type=gha |           cache-from: type=gha | ||||||
|  | |||||||
							
								
								
									
										44
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										44
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | |||||||
|  | --- | ||||||
| name: authentik-on-release | name: authentik-on-release | ||||||
|  |  | ||||||
| on: | on: | ||||||
| @ -19,6 +20,10 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|  |         env: | ||||||
|  |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/server,beryju/authentik | ||||||
|       - name: Docker Login Registry |       - name: Docker Login Registry | ||||||
|         uses: docker/login-action@v3 |         uses: docker/login-action@v3 | ||||||
|         with: |         with: | ||||||
| @ -38,21 +43,12 @@ jobs: | |||||||
|         uses: docker/build-push-action@v5 |         uses: docker/build-push-action@v5 | ||||||
|         with: |         with: | ||||||
|           context: . |           context: . | ||||||
|           push: ${{ github.event_name == 'release' }} |           push: true | ||||||
|           secrets: | |           secrets: | | ||||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} |             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} |             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||||
|           tags: | |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|             beryju/authentik:${{ steps.ev.outputs.version }}, |  | ||||||
|             beryju/authentik:${{ steps.ev.outputs.versionFamily }}, |  | ||||||
|             beryju/authentik:latest, |  | ||||||
|             ghcr.io/goauthentik/server:${{ steps.ev.outputs.version }}, |  | ||||||
|             ghcr.io/goauthentik/server:${{ steps.ev.outputs.versionFamily }}, |  | ||||||
|             ghcr.io/goauthentik/server:latest |  | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|           build-args: | |  | ||||||
|             VERSION=${{ steps.ev.outputs.version }} |  | ||||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} |  | ||||||
|   build-outpost: |   build-outpost: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     permissions: |     permissions: | ||||||
| @ -78,6 +74,10 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|  |         env: | ||||||
|  |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }} | ||||||
|       - name: make empty clients |       - name: make empty clients | ||||||
|         run: | |         run: | | ||||||
|           mkdir -p ./gen-ts-api |           mkdir -p ./gen-ts-api | ||||||
| @ -96,20 +96,11 @@ jobs: | |||||||
|       - name: Build Docker Image |       - name: Build Docker Image | ||||||
|         uses: docker/build-push-action@v5 |         uses: docker/build-push-action@v5 | ||||||
|         with: |         with: | ||||||
|           push: ${{ github.event_name == 'release' }} |           push: true | ||||||
|           tags: | |           tags: ${{ steps.ev.outputs.imageTags }} | ||||||
|             beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.version }}, |  | ||||||
|             beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }}, |  | ||||||
|             beryju/authentik-${{ matrix.type }}:latest, |  | ||||||
|             ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.version }}, |  | ||||||
|             ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }}, |  | ||||||
|             ghcr.io/goauthentik/${{ matrix.type }}:latest |  | ||||||
|           file: ${{ matrix.type }}.Dockerfile |           file: ${{ matrix.type }}.Dockerfile | ||||||
|           platforms: linux/amd64,linux/arm64 |           platforms: linux/amd64,linux/arm64 | ||||||
|           context: . |           context: . | ||||||
|           build-args: | |  | ||||||
|             VERSION=${{ steps.ev.outputs.version }} |  | ||||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} |  | ||||||
|   build-outpost-binary: |   build-outpost-binary: | ||||||
|     timeout-minutes: 120 |     timeout-minutes: 120 | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
| @ -181,15 +172,18 @@ jobs: | |||||||
|       - name: prepare variables |       - name: prepare variables | ||||||
|         uses: ./.github/actions/docker-push-variables |         uses: ./.github/actions/docker-push-variables | ||||||
|         id: ev |         id: ev | ||||||
|  |         env: | ||||||
|  |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|  |         with: | ||||||
|  |           image-name: ghcr.io/goauthentik/server | ||||||
|       - name: Get static files from docker image |       - name: Get static files from docker image | ||||||
|         run: | |         run: | | ||||||
|           docker pull ghcr.io/goauthentik/server:latest |           docker pull ${{ steps.ev.outputs.imageMainTag }} | ||||||
|           container=$(docker container create ghcr.io/goauthentik/server:latest) |           container=$(docker container create ${{ steps.ev.outputs.imageMainTag }}) | ||||||
|           docker cp ${container}:web/ . |           docker cp ${container}:web/ . | ||||||
|       - name: Create a Sentry.io release |       - name: Create a Sentry.io release | ||||||
|         uses: getsentry/action-release@v1 |         uses: getsentry/action-release@v1 | ||||||
|         continue-on-error: true |         continue-on-error: true | ||||||
|         if: ${{ github.event_name == 'release' }} |  | ||||||
|         env: |         env: | ||||||
|           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} |           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} | ||||||
|           SENTRY_ORG: authentik-security-inc |           SENTRY_ORG: authentik-security-inc | ||||||
|  | |||||||
							
								
								
									
										17
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | |||||||
|  | --- | ||||||
| name: authentik-on-tag | name: authentik-on-tag | ||||||
|  |  | ||||||
| on: | on: | ||||||
| @ -28,13 +29,13 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           app_id: ${{ secrets.GH_APP_ID }} |           app_id: ${{ secrets.GH_APP_ID }} | ||||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} |           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} | ||||||
|       - name: Extract version number |       - name: prepare variables | ||||||
|         id: get_version |         uses: ./.github/actions/docker-push-variables | ||||||
|         uses: actions/github-script@v7 |         id: ev | ||||||
|  |         env: | ||||||
|  |           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||||
|         with: |         with: | ||||||
|           github-token: ${{ steps.generate_token.outputs.token }} |           image-name: ghcr.io/goauthentik/server | ||||||
|           script: | |  | ||||||
|             return context.payload.ref.replace(/\/refs\/tags\/version\//, ''); |  | ||||||
|       - name: Create Release |       - name: Create Release | ||||||
|         id: create_release |         id: create_release | ||||||
|         uses: actions/create-release@v1.1.4 |         uses: actions/create-release@v1.1.4 | ||||||
| @ -42,6 +43,6 @@ jobs: | |||||||
|           GITHUB_TOKEN: ${{ steps.generate_token.outputs.token }} |           GITHUB_TOKEN: ${{ steps.generate_token.outputs.token }} | ||||||
|         with: |         with: | ||||||
|           tag_name: ${{ github.ref }} |           tag_name: ${{ github.ref }} | ||||||
|           release_name: Release ${{ steps.get_version.outputs.result }} |           release_name: Release ${{ steps.ev.outputs.version }} | ||||||
|           draft: true |           draft: true | ||||||
|           prerelease: false |           prerelease: ${{ steps.ev.outputs.prerelease == 'true' }} | ||||||
|  | |||||||
| @ -1,9 +1,8 @@ | |||||||
| name: authentik-backend-translate-compile | --- | ||||||
|  | name: authentik-backend-translate-extract-compile | ||||||
| on: | on: | ||||||
|   push: |   schedule: | ||||||
|     branches: [main] |     - cron: "0 0 * * *" # every day at midnight | ||||||
|     paths: |  | ||||||
|       - "locale/**" |  | ||||||
|   workflow_dispatch: |   workflow_dispatch: | ||||||
| 
 | 
 | ||||||
| env: | env: | ||||||
| @ -25,16 +24,20 @@ jobs: | |||||||
|           token: ${{ steps.generate_token.outputs.token }} |           token: ${{ steps.generate_token.outputs.token }} | ||||||
|       - name: Setup authentik env |       - name: Setup authentik env | ||||||
|         uses: ./.github/actions/setup |         uses: ./.github/actions/setup | ||||||
|  |       - name: run extract | ||||||
|  |         run: | | ||||||
|  |           poetry run make i18n-extract | ||||||
|       - name: run compile |       - name: run compile | ||||||
|         run: poetry run ak compilemessages |         run: | | ||||||
|  |           poetry run ak compilemessages | ||||||
|  |           make web-check-compile | ||||||
|       - name: Create Pull Request |       - name: Create Pull Request | ||||||
|         uses: peter-evans/create-pull-request@v6 |         uses: peter-evans/create-pull-request@v6 | ||||||
|         id: cpr |  | ||||||
|         with: |         with: | ||||||
|           token: ${{ steps.generate_token.outputs.token }} |           token: ${{ steps.generate_token.outputs.token }} | ||||||
|           branch: compile-backend-translation |           branch: extract-compile-backend-translation | ||||||
|           commit-message: "core: compile backend translations" |           commit-message: "core, web: update translations" | ||||||
|           title: "core: compile backend translations" |           title: "core, web: update translations" | ||||||
|           body: "core: compile backend translations" |           body: "core, web: update translations" | ||||||
|           delete-branch: true |           delete-branch: true | ||||||
|           signoff: true |           signoff: true | ||||||
							
								
								
									
										13
									
								
								Dockerfile
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								Dockerfile
									
									
									
									
									
								
							| @ -37,7 +37,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api | |||||||
| RUN npm run build | RUN npm run build | ||||||
|  |  | ||||||
| # Stage 3: Build go proxy | # Stage 3: Build go proxy | ||||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21.6-bookworm AS go-builder | FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.0-bookworm AS go-builder | ||||||
|  |  | ||||||
| ARG TARGETOS | ARG TARGETOS | ||||||
| ARG TARGETARCH | ARG TARGETARCH | ||||||
| @ -83,7 +83,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | |||||||
|     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" |     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||||
|  |  | ||||||
| # Stage 5: Python dependencies | # Stage 5: Python dependencies | ||||||
| FROM docker.io/python:3.12.1-slim-bookworm AS python-deps | FROM docker.io/python:3.12.2-slim-bookworm AS python-deps | ||||||
|  |  | ||||||
| WORKDIR /ak-root/poetry | WORKDIR /ak-root/poetry | ||||||
|  |  | ||||||
| @ -103,12 +103,13 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | |||||||
|     --mount=type=cache,target=/root/.cache/pip \ |     --mount=type=cache,target=/root/.cache/pip \ | ||||||
|     --mount=type=cache,target=/root/.cache/pypoetry \ |     --mount=type=cache,target=/root/.cache/pypoetry \ | ||||||
|     python -m venv /ak-root/venv/ && \ |     python -m venv /ak-root/venv/ && \ | ||||||
|     pip3 install --upgrade pip && \ |     bash -c "source ${VENV_PATH}/bin/activate && \ | ||||||
|     pip3 install poetry && \ |         pip3 install --upgrade pip && \ | ||||||
|     poetry install --only=main --no-ansi --no-interaction |         pip3 install poetry && \ | ||||||
|  |         poetry install --only=main --no-ansi --no-interaction --no-root" | ||||||
|  |  | ||||||
| # Stage 6: Run | # Stage 6: Run | ||||||
| FROM docker.io/python:3.12.1-slim-bookworm AS final-image | FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||||
|  |  | ||||||
| ARG GIT_BUILD_HASH | ARG GIT_BUILD_HASH | ||||||
| ARG VERSION | ARG VERSION | ||||||
|  | |||||||
							
								
								
									
										47
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								Makefile
									
									
									
									
									
								
							| @ -5,9 +5,12 @@ PWD = $(shell pwd) | |||||||
| UID = $(shell id -u) | UID = $(shell id -u) | ||||||
| GID = $(shell id -g) | GID = $(shell id -g) | ||||||
| NPM_VERSION = $(shell python -m scripts.npm_version) | NPM_VERSION = $(shell python -m scripts.npm_version) | ||||||
| PY_SOURCES = authentik tests scripts lifecycle | PY_SOURCES = authentik tests scripts lifecycle .github | ||||||
| DOCKER_IMAGE ?= "authentik:test" | DOCKER_IMAGE ?= "authentik:test" | ||||||
|  |  | ||||||
|  | GEN_API_TS = "gen-ts-api" | ||||||
|  | GEN_API_GO = "gen-go-api" | ||||||
|  |  | ||||||
| pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | ||||||
| pg_host := $(shell python -m authentik.lib.config postgresql.host 2>/dev/null) | pg_host := $(shell python -m authentik.lib.config postgresql.host 2>/dev/null) | ||||||
| pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null) | pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null) | ||||||
| @ -76,7 +79,15 @@ migrate: ## Run the Authentik Django server's migrations | |||||||
| i18n-extract: core-i18n-extract web-i18n-extract  ## Extract strings that require translation into files to send to a translation service | i18n-extract: core-i18n-extract web-i18n-extract  ## Extract strings that require translation into files to send to a translation service | ||||||
|  |  | ||||||
| core-i18n-extract: | core-i18n-extract: | ||||||
| 	ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en | 	ak makemessages \ | ||||||
|  | 		--add-location file \ | ||||||
|  | 		--no-obsolete \ | ||||||
|  | 		--ignore web \ | ||||||
|  | 		--ignore internal \ | ||||||
|  | 		--ignore ${GEN_API_TS} \ | ||||||
|  | 		--ignore ${GEN_API_GO} \ | ||||||
|  | 		--ignore website \ | ||||||
|  | 		-l en | ||||||
|  |  | ||||||
| install: web-install website-install core-install  ## Install all requires dependencies for `web`, `website` and `core` | install: web-install website-install core-install  ## Install all requires dependencies for `web`, `website` and `core` | ||||||
|  |  | ||||||
| @ -114,7 +125,7 @@ gen-diff:  ## (Release) generate the changelog diff between the current schema a | |||||||
| 	docker run \ | 	docker run \ | ||||||
| 		--rm -v ${PWD}:/local \ | 		--rm -v ${PWD}:/local \ | ||||||
| 		--user ${UID}:${GID} \ | 		--user ${UID}:${GID} \ | ||||||
| 		docker.io/openapitools/openapi-diff:2.1.0-beta.6 \ | 		docker.io/openapitools/openapi-diff:2.1.0-beta.8 \ | ||||||
| 		--markdown /local/diff.md \ | 		--markdown /local/diff.md \ | ||||||
| 		/local/old_schema.yml /local/schema.yml | 		/local/old_schema.yml /local/schema.yml | ||||||
| 	rm old_schema.yml | 	rm old_schema.yml | ||||||
| @ -123,11 +134,11 @@ gen-diff:  ## (Release) generate the changelog diff between the current schema a | |||||||
| 	npx prettier --write diff.md | 	npx prettier --write diff.md | ||||||
|  |  | ||||||
| gen-clean-ts:  ## Remove generated API client for Typescript | gen-clean-ts:  ## Remove generated API client for Typescript | ||||||
| 	rm -rf gen-ts-api/ | 	rm -rf ./${GEN_API_TS}/ | ||||||
| 	rm -rf web/node_modules/@goauthentik/api/ | 	rm -rf ./web/node_modules/@goauthentik/api/ | ||||||
|  |  | ||||||
| gen-clean-go:  ## Remove generated API client for Go | gen-clean-go:  ## Remove generated API client for Go | ||||||
| 	rm -rf gen-go-api/ | 	rm -rf ./${GEN_API_GO}/ | ||||||
|  |  | ||||||
| gen-clean: gen-clean-ts gen-clean-go  ## Remove generated API clients | gen-clean: gen-clean-ts gen-clean-go  ## Remove generated API clients | ||||||
|  |  | ||||||
| @ -138,31 +149,31 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | |||||||
| 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | ||||||
| 		-i /local/schema.yml \ | 		-i /local/schema.yml \ | ||||||
| 		-g typescript-fetch \ | 		-g typescript-fetch \ | ||||||
| 		-o /local/gen-ts-api \ | 		-o /local/${GEN_API_TS} \ | ||||||
| 		-c /local/scripts/api-ts-config.yaml \ | 		-c /local/scripts/api-ts-config.yaml \ | ||||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||||
| 		--git-repo-id authentik \ | 		--git-repo-id authentik \ | ||||||
| 		--git-user-id goauthentik | 		--git-user-id goauthentik | ||||||
| 	mkdir -p web/node_modules/@goauthentik/api | 	mkdir -p web/node_modules/@goauthentik/api | ||||||
| 	cd gen-ts-api && npm i | 	cd ./${GEN_API_TS} && npm i | ||||||
| 	\cp -rfv gen-ts-api/* web/node_modules/@goauthentik/api | 	\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||||
|  |  | ||||||
| gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | ||||||
| 	mkdir -p ./gen-go-api ./gen-go-api/templates | 	mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates | ||||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./gen-go-api/config.yaml | 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml | ||||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./gen-go-api/templates/README.mustache | 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./${GEN_API_GO}/templates/README.mustache | ||||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./gen-go-api/templates/go.mod.mustache | 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./${GEN_API_GO}/templates/go.mod.mustache | ||||||
| 	cp schema.yml ./gen-go-api/ | 	cp schema.yml ./${GEN_API_GO}/ | ||||||
| 	docker run \ | 	docker run \ | ||||||
| 		--rm -v ${PWD}/gen-go-api:/local \ | 		--rm -v ${PWD}/${GEN_API_GO}:/local \ | ||||||
| 		--user ${UID}:${GID} \ | 		--user ${UID}:${GID} \ | ||||||
| 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | ||||||
| 		-i /local/schema.yml \ | 		-i /local/schema.yml \ | ||||||
| 		-g go \ | 		-g go \ | ||||||
| 		-o /local/ \ | 		-o /local/ \ | ||||||
| 		-c /local/config.yaml | 		-c /local/config.yaml | ||||||
| 	go mod edit -replace goauthentik.io/api/v3=./gen-go-api | 	go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO} | ||||||
| 	rm -rf ./gen-go-api/config.yaml ./gen-go-api/templates/ | 	rm -rf ./${GEN_API_GO}/config.yaml ./${GEN_API_GO}/templates/ | ||||||
|  |  | ||||||
| gen-dev-config:  ## Generate a local development config file | gen-dev-config:  ## Generate a local development config file | ||||||
| 	python -m scripts.generate_config | 	python -m scripts.generate_config | ||||||
| @ -176,7 +187,7 @@ gen: gen-build gen-client-ts | |||||||
| web-build: web-install  ## Build the Authentik UI | web-build: web-install  ## Build the Authentik UI | ||||||
| 	cd web && npm run build | 	cd web && npm run build | ||||||
|  |  | ||||||
| web: web-lint-fix web-lint web-check-compile web-i18n-extract  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | web: web-lint-fix web-lint web-check-compile  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||||
|  |  | ||||||
| web-install:  ## Install the necessary libraries to build the Authentik UI | web-install:  ## Install the necessary libraries to build the Authentik UI | ||||||
| 	cd web && npm ci | 	cd web && npm ci | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ | |||||||
| from os import environ | from os import environ | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| __version__ = "2023.10.7" | __version__ = "2024.2.4" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -15,7 +15,3 @@ class AuthentikAdminConfig(ManagedAppConfig): | |||||||
|     label = "authentik_admin" |     label = "authentik_admin" | ||||||
|     verbose_name = "authentik Admin" |     verbose_name = "authentik Admin" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_admin_signals(self): |  | ||||||
|         """Load admin signals""" |  | ||||||
|         self.import_module("authentik.admin.signals") |  | ||||||
|  | |||||||
| @ -1,35 +0,0 @@ | |||||||
| """test decorators api""" |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from guardian.shortcuts import assign_perm |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.models import Application, User |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestAPIDecorators(APITestCase): |  | ||||||
|     """test decorators api""" |  | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |  | ||||||
|         super().setUp() |  | ||||||
|         self.user = User.objects.create(username="test-user") |  | ||||||
|  |  | ||||||
|     def test_obj_perm_denied(self): |  | ||||||
|         """Test object perm denied""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 403) |  | ||||||
|  |  | ||||||
|     def test_other_perm_denied(self): |  | ||||||
|         """Test other perm denied""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) |  | ||||||
|         assign_perm("authentik_core.view_application", self.user, app) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 403) |  | ||||||
| @ -68,7 +68,11 @@ class ConfigView(APIView): | |||||||
|         """Get all capabilities this server instance supports""" |         """Get all capabilities this server instance supports""" | ||||||
|         caps = [] |         caps = [] | ||||||
|         deb_test = settings.DEBUG or settings.TEST |         deb_test = settings.DEBUG or settings.TEST | ||||||
|         if Path(settings.MEDIA_ROOT).is_mount() or deb_test: |         if ( | ||||||
|  |             CONFIG.get("storage.media.backend", "file") == "s3" | ||||||
|  |             or Path(settings.STORAGES["default"]["OPTIONS"]["location"]).is_mount() | ||||||
|  |             or deb_test | ||||||
|  |         ): | ||||||
|             caps.append(Capabilities.CAN_SAVE_MEDIA) |             caps.append(Capabilities.CAN_SAVE_MEDIA) | ||||||
|         for processor in get_context_processors(): |         for processor in get_context_processors(): | ||||||
|             if cap := processor.capability(): |             if cap := processor.capability(): | ||||||
|  | |||||||
| @ -10,13 +10,13 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ListSerializer, ModelSerializer | from rest_framework.serializers import ListSerializer, ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.models import BlueprintInstance | from authentik.blueprints.models import BlueprintInstance | ||||||
| from authentik.blueprints.v1.importer import Importer | from authentik.blueprints.v1.importer import Importer | ||||||
| from authentik.blueprints.v1.oci import OCI_PREFIX | from authentik.blueprints.v1.oci import OCI_PREFIX | ||||||
| from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict | from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class ManagedSerializer: | class ManagedSerializer: | ||||||
|  | |||||||
| @ -21,10 +21,27 @@ class ManagedAppConfig(AppConfig): | |||||||
|         self.logger = get_logger().bind(app_name=app_name) |         self.logger = get_logger().bind(app_name=app_name) | ||||||
|  |  | ||||||
|     def ready(self) -> None: |     def ready(self) -> None: | ||||||
|  |         self.import_related() | ||||||
|         self.reconcile_global() |         self.reconcile_global() | ||||||
|         self.reconcile_tenant() |         self.reconcile_tenant() | ||||||
|         return super().ready() |         return super().ready() | ||||||
|  |  | ||||||
|  |     def import_related(self): | ||||||
|  |         """Automatically import related modules which rely on just being imported | ||||||
|  |         to register themselves (mainly django signals and celery tasks)""" | ||||||
|  |  | ||||||
|  |         def import_relative(rel_module: str): | ||||||
|  |             try: | ||||||
|  |                 module_name = f"{self.name}.{rel_module}" | ||||||
|  |                 import_module(module_name) | ||||||
|  |                 self.logger.info("Imported related module", module=module_name) | ||||||
|  |             except ModuleNotFoundError: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |         import_relative("checks") | ||||||
|  |         import_relative("tasks") | ||||||
|  |         import_relative("signals") | ||||||
|  |  | ||||||
|     def import_module(self, path: str): |     def import_module(self, path: str): | ||||||
|         """Load module""" |         """Load module""" | ||||||
|         import_module(path) |         import_module(path) | ||||||
|  | |||||||
| @ -74,7 +74,7 @@ class Exporter: | |||||||
|  |  | ||||||
|  |  | ||||||
| class FlowExporter(Exporter): | class FlowExporter(Exporter): | ||||||
|     """Exporter customised to only return objects related to `flow`""" |     """Exporter customized to only return objects related to `flow`""" | ||||||
|  |  | ||||||
|     flow: Flow |     flow: Flow | ||||||
|     with_policies: bool |     with_policies: bool | ||||||
|  | |||||||
| @ -39,7 +39,8 @@ from authentik.core.models import ( | |||||||
|     Source, |     Source, | ||||||
|     UserSourceConnection, |     UserSourceConnection, | ||||||
| ) | ) | ||||||
| from authentik.enterprise.models import LicenseKey, LicenseUsage | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.models import LicenseUsage | ||||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | from authentik.enterprise.providers.rac.models import ConnectionToken | ||||||
| from authentik.events.models import SystemTask | from authentik.events.models import SystemTask | ||||||
| from authentik.events.utils import cleanse_dict | from authentik.events.utils import cleanse_dict | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| from dataclasses import asdict, dataclass, field | from dataclasses import asdict, dataclass, field | ||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  | from sys import platform | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| from dacite.core import from_dict | from dacite.core import from_dict | ||||||
| @ -60,7 +61,12 @@ def start_blueprint_watcher(): | |||||||
|     if _file_watcher_started: |     if _file_watcher_started: | ||||||
|         return |         return | ||||||
|     observer = Observer() |     observer = Observer() | ||||||
|     observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True) |     kwargs = {} | ||||||
|  |     if platform.startswith("linux"): | ||||||
|  |         kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent) | ||||||
|  |     observer.schedule( | ||||||
|  |         BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs | ||||||
|  |     ) | ||||||
|     observer.start() |     observer.start() | ||||||
|     _file_watcher_started = True |     _file_watcher_started = True | ||||||
|  |  | ||||||
| @ -68,26 +74,36 @@ def start_blueprint_watcher(): | |||||||
| class BlueprintEventHandler(FileSystemEventHandler): | class BlueprintEventHandler(FileSystemEventHandler): | ||||||
|     """Event handler for blueprint events""" |     """Event handler for blueprint events""" | ||||||
|  |  | ||||||
|     def on_any_event(self, event: FileSystemEvent): |     # We only ever get creation and modification events. | ||||||
|         if not isinstance(event, (FileCreatedEvent, FileModifiedEvent)): |     # See the creation of the Observer instance above for the event filtering. | ||||||
|             return |  | ||||||
|  |     # Even though we filter to only get file events, we might still get | ||||||
|  |     # directory events as some implementations such as inotify do not support | ||||||
|  |     # filtering on file/directory. | ||||||
|  |  | ||||||
|  |     def dispatch(self, event: FileSystemEvent) -> None: | ||||||
|  |         """Call specific event handler method. Ignores directory changes.""" | ||||||
|         if event.is_directory: |         if event.is_directory: | ||||||
|             return |             return None | ||||||
|  |         return super().dispatch(event) | ||||||
|  |  | ||||||
|  |     def on_created(self, event: FileSystemEvent): | ||||||
|  |         """Process file creation""" | ||||||
|  |         LOGGER.debug("new blueprint file created, starting discovery") | ||||||
|  |         for tenant in Tenant.objects.filter(ready=True): | ||||||
|  |             with tenant: | ||||||
|  |                 blueprints_discovery.delay() | ||||||
|  |  | ||||||
|  |     def on_modified(self, event: FileSystemEvent): | ||||||
|  |         """Process file modification""" | ||||||
|  |         path = Path(event.src_path) | ||||||
|         root = Path(CONFIG.get("blueprints_dir")).absolute() |         root = Path(CONFIG.get("blueprints_dir")).absolute() | ||||||
|         path = Path(event.src_path).absolute() |  | ||||||
|         rel_path = str(path.relative_to(root)) |         rel_path = str(path.relative_to(root)) | ||||||
|         for tenant in Tenant.objects.filter(ready=True): |         for tenant in Tenant.objects.filter(ready=True): | ||||||
|             with tenant: |             with tenant: | ||||||
|                 root = Path(CONFIG.get("blueprints_dir")).absolute() |                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): | ||||||
|                 path = Path(event.src_path).absolute() |                     LOGGER.debug("modified blueprint file, starting apply", instance=instance) | ||||||
|                 rel_path = str(path.relative_to(root)) |                     apply_blueprint.delay(instance.pk.hex) | ||||||
|                 if isinstance(event, FileCreatedEvent): |  | ||||||
|                     LOGGER.debug("new blueprint file created, starting discovery", path=rel_path) |  | ||||||
|                     blueprints_discovery.delay(rel_path) |  | ||||||
|                 if isinstance(event, FileModifiedEvent): |  | ||||||
|                     for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): |  | ||||||
|                         LOGGER.debug("modified blueprint file, starting apply", instance=instance) |  | ||||||
|                         apply_blueprint.delay(instance.pk.hex) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task( | @CELERY_APP.task( | ||||||
|  | |||||||
| @ -9,6 +9,7 @@ from sentry_sdk.hub import Hub | |||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
|  | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
| _q_default = Q(default=True) | _q_default = Q(default=True) | ||||||
| DEFAULT_BRAND = Brand(domain="fallback") | DEFAULT_BRAND = Brand(domain="fallback") | ||||||
| @ -30,13 +31,14 @@ def get_brand_for_request(request: HttpRequest) -> Brand: | |||||||
| def context_processor(request: HttpRequest) -> dict[str, Any]: | def context_processor(request: HttpRequest) -> dict[str, Any]: | ||||||
|     """Context Processor that injects brand object into every template""" |     """Context Processor that injects brand object into every template""" | ||||||
|     brand = getattr(request, "brand", DEFAULT_BRAND) |     brand = getattr(request, "brand", DEFAULT_BRAND) | ||||||
|  |     tenant = getattr(request, "tenant", Tenant()) | ||||||
|     trace = "" |     trace = "" | ||||||
|     span = Hub.current.scope.span |     span = Hub.current.scope.span | ||||||
|     if span: |     if span: | ||||||
|         trace = span.to_traceparent() |         trace = span.to_traceparent() | ||||||
|     return { |     return { | ||||||
|         "brand": brand, |         "brand": brand, | ||||||
|         "footer_links": request.tenant.footer_links, |         "footer_links": tenant.footer_links, | ||||||
|         "sentry_trace": trace, |         "sentry_trace": trace, | ||||||
|         "version": get_full_version(), |         "version": get_full_version(), | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from copy import copy | from copy import copy | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from typing import Optional | from typing import Iterator, Optional | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| @ -23,7 +23,6 @@ from structlog.stdlib import get_logger | |||||||
| from structlog.testing import capture_logs | from structlog.testing import capture_logs | ||||||
|  |  | ||||||
| from authentik.admin.api.metrics import CoordinateSerializer | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| @ -39,6 +38,7 @@ from authentik.lib.utils.file import ( | |||||||
| from authentik.policies.api.exec import PolicyTestResultSerializer | from authentik.policies.api.exec import PolicyTestResultSerializer | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.types import PolicyResult | from authentik.policies.types import PolicyResult | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.rbac.filters import ObjectFilter | from authentik.rbac.filters import ObjectFilter | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -131,14 +131,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|         return queryset |         return queryset | ||||||
|  |  | ||||||
|     def _get_allowed_applications( |     def _get_allowed_applications( | ||||||
|         self, queryset: QuerySet, user: Optional[User] = None |         self, pagined_apps: Iterator[Application], user: Optional[User] = None | ||||||
|     ) -> list[Application]: |     ) -> list[Application]: | ||||||
|         applications = [] |         applications = [] | ||||||
|         request = self.request._request |         request = self.request._request | ||||||
|         if user: |         if user: | ||||||
|             request = copy(request) |             request = copy(request) | ||||||
|             request.user = user |             request.user = user | ||||||
|         for application in queryset: |         for application in pagined_apps: | ||||||
|             engine = PolicyEngine(application, request.user, request) |             engine = PolicyEngine(application, request.user, request) | ||||||
|             engine.build() |             engine.build() | ||||||
|             if engine.passing: |             if engine.passing: | ||||||
| @ -215,7 +215,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|             return super().list(request) |             return super().list(request) | ||||||
|  |  | ||||||
|         queryset = self._filter_queryset_for_list(self.get_queryset()) |         queryset = self._filter_queryset_for_list(self.get_queryset()) | ||||||
|         self.paginate_queryset(queryset) |         pagined_apps = self.paginate_queryset(queryset) | ||||||
|  |  | ||||||
|         if "for_user" in request.query_params: |         if "for_user" in request.query_params: | ||||||
|             try: |             try: | ||||||
| @ -229,18 +229,18 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | |||||||
|                     raise ValidationError({"for_user": "User not found"}) |                     raise ValidationError({"for_user": "User not found"}) | ||||||
|             except ValueError as exc: |             except ValueError as exc: | ||||||
|                 raise ValidationError from exc |                 raise ValidationError from exc | ||||||
|             allowed_applications = self._get_allowed_applications(queryset, user=for_user) |             allowed_applications = self._get_allowed_applications(pagined_apps, user=for_user) | ||||||
|             serializer = self.get_serializer(allowed_applications, many=True) |             serializer = self.get_serializer(allowed_applications, many=True) | ||||||
|             return self.get_paginated_response(serializer.data) |             return self.get_paginated_response(serializer.data) | ||||||
|  |  | ||||||
|         allowed_applications = [] |         allowed_applications = [] | ||||||
|         if not should_cache: |         if not should_cache: | ||||||
|             allowed_applications = self._get_allowed_applications(queryset) |             allowed_applications = self._get_allowed_applications(pagined_apps) | ||||||
|         if should_cache: |         if should_cache: | ||||||
|             allowed_applications = cache.get(user_app_cache_key(self.request.user.pk)) |             allowed_applications = cache.get(user_app_cache_key(self.request.user.pk)) | ||||||
|             if not allowed_applications: |             if not allowed_applications: | ||||||
|                 LOGGER.debug("Caching allowed application list") |                 LOGGER.debug("Caching allowed application list") | ||||||
|                 allowed_applications = self._get_allowed_applications(queryset) |                 allowed_applications = self._get_allowed_applications(pagined_apps) | ||||||
|                 cache.set( |                 cache.set( | ||||||
|                     user_app_cache_key(self.request.user.pk), |                     user_app_cache_key(self.request.user.pk), | ||||||
|                     allowed_applications, |                     allowed_applications, | ||||||
|  | |||||||
| @ -15,11 +15,11 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.rbac.api.roles import RoleSerializer | from authentik.rbac.api.roles import RoleSerializer | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupMemberSerializer(ModelSerializer): | class GroupMemberSerializer(ModelSerializer): | ||||||
|  | |||||||
| @ -14,7 +14,6 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer, SerializerMethodField | from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import GenericViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.api import ManagedSerializer | from authentik.blueprints.api import ManagedSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||||
| @ -23,6 +22,7 @@ from authentik.core.models import PropertyMapping | |||||||
| from authentik.events.utils import sanitize_item | from authentik.events.utils import sanitize_item | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| from authentik.policies.api.exec import PolicyTestSerializer | from authentik.policies.api.exec import PolicyTestSerializer | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class PropertyMappingTestResultSerializer(PassiveSerializer): | class PropertyMappingTestResultSerializer(PassiveSerializer): | ||||||
| @ -118,7 +118,11 @@ class PropertyMappingViewSet( | |||||||
|     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) |     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||||
|     def test(self, request: Request, pk: str) -> Response: |     def test(self, request: Request, pk: str) -> Response: | ||||||
|         """Test Property Mapping""" |         """Test Property Mapping""" | ||||||
|         mapping: PropertyMapping = self.get_object() |         _mapping: PropertyMapping = self.get_object() | ||||||
|  |         # Use `get_subclass` to get correct class and correct `.evaluate` implementation | ||||||
|  |         mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk) | ||||||
|  |         # FIXME: when we separate policy mappings between ones for sources | ||||||
|  |         # and ones for providers, we need to make the user field optional for the source mapping | ||||||
|         test_params = PolicyTestSerializer(data=request.data) |         test_params = PolicyTestSerializer(data=request.data) | ||||||
|         if not test_params.is_valid(): |         if not test_params.is_valid(): | ||||||
|             return Response(test_params.errors, status=400) |             return Response(test_params.errors, status=400) | ||||||
|  | |||||||
| @ -16,7 +16,6 @@ from rest_framework.viewsets import GenericViewSet | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer | from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer | ||||||
| @ -30,6 +29,7 @@ from authentik.lib.utils.file import ( | |||||||
| ) | ) | ||||||
| from authentik.lib.utils.reflection import all_subclasses | from authentik.lib.utils.reflection import all_subclasses | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | |||||||
| @ -15,15 +15,15 @@ from rest_framework.serializers import ModelSerializer | |||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.authorization import OwnerSuperuserPermissions | from authentik.api.authorization import OwnerSuperuserPermissions | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.api import ManagedSerializer | from authentik.blueprints.api import ManagedSerializer | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.users import UserSerializer | from authentik.core.api.users import UserSerializer | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class TokenSerializer(ManagedSerializer, ModelSerializer): | class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||||
| @ -36,6 +36,13 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | |||||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: |         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||||
|             self.fields["key"] = CharField(required=False) |             self.fields["key"] = CharField(required=False) | ||||||
|  |  | ||||||
|  |     def validate_user(self, user: User): | ||||||
|  |         """Ensure user of token cannot be changed""" | ||||||
|  |         if self.instance and self.instance.user_id: | ||||||
|  |             if user.pk != self.instance.user_id: | ||||||
|  |                 raise ValidationError("User cannot be changed") | ||||||
|  |         return user | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: |     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||||
|         """Ensure only API or App password tokens are created.""" |         """Ensure only API or App password tokens are created.""" | ||||||
|         request: Request = self.context.get("request") |         request: Request = self.context.get("request") | ||||||
|  | |||||||
| @ -49,7 +49,6 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.admin.api.metrics import CoordinateSerializer | from authentik.admin.api.metrics import CoordinateSerializer | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| @ -74,6 +73,7 @@ from authentik.flows.models import FlowToken | |||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner | from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner | ||||||
| from authentik.flows.views.executor import QS_KEY_TOKEN | from authentik.flows.views.executor import QS_KEY_TOKEN | ||||||
| from authentik.lib.avatars import get_avatar | from authentik.lib.avatars import get_avatar | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.stages.email.models import EmailStage | from authentik.stages.email.models import EmailStage | ||||||
| from authentik.stages.email.tasks import send_mails | from authentik.stages.email.tasks import send_mails | ||||||
| from authentik.stages.email.utils import TemplateEmailMessage | from authentik.stages.email.utils import TemplateEmailMessage | ||||||
| @ -154,7 +154,7 @@ class UserSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     def get_avatar(self, user: User) -> str: |     def get_avatar(self, user: User) -> str: | ||||||
|         """User's avatar, either a http/https URL or a data URI""" |         """User's avatar, either a http/https URL or a data URI""" | ||||||
|         return get_avatar(user, self.context["request"]) |         return get_avatar(user, self.context.get("request")) | ||||||
|  |  | ||||||
|     def validate_path(self, path: str) -> str: |     def validate_path(self, path: str) -> str: | ||||||
|         """Validate path""" |         """Validate path""" | ||||||
| @ -218,7 +218,7 @@ class UserSelfSerializer(ModelSerializer): | |||||||
|  |  | ||||||
|     def get_avatar(self, user: User) -> str: |     def get_avatar(self, user: User) -> str: | ||||||
|         """User's avatar, either a http/https URL or a data URI""" |         """User's avatar, either a http/https URL or a data URI""" | ||||||
|         return get_avatar(user, self.context["request"]) |         return get_avatar(user, self.context.get("request")) | ||||||
|  |  | ||||||
|     @extend_schema_field( |     @extend_schema_field( | ||||||
|         ListSerializer( |         ListSerializer( | ||||||
| @ -533,7 +533,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|             400: OpenApiResponse(description="Bad request"), |             400: OpenApiResponse(description="Bad request"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, methods=["POST"]) |     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||||
|     def set_password(self, request: Request, pk: int) -> Response: |     def set_password(self, request: Request, pk: int) -> Response: | ||||||
|         """Set password for user""" |         """Set password for user""" | ||||||
|         user: User = self.get_object() |         user: User = self.get_object() | ||||||
| @ -611,7 +611,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|         email_stage: EmailStage = stages.first() |         email_stage: EmailStage = stages.first() | ||||||
|         message = TemplateEmailMessage( |         message = TemplateEmailMessage( | ||||||
|             subject=_(email_stage.subject), |             subject=_(email_stage.subject), | ||||||
|             to=[for_user.email], |             to=[(for_user.name, for_user.email)], | ||||||
|             template_name=email_stage.template, |             template_name=email_stage.template, | ||||||
|             language=for_user.locale(request), |             language=for_user.locale(request), | ||||||
|             template_context={ |             template_context={ | ||||||
| @ -631,7 +631,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|             "401": OpenApiResponse(description="Access denied"), |             "401": OpenApiResponse(description="Access denied"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, methods=["POST"]) |     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||||
|     def impersonate(self, request: Request, pk: int) -> Response: |     def impersonate(self, request: Request, pk: int) -> Response: | ||||||
|         """Impersonate a user""" |         """Impersonate a user""" | ||||||
|         if not request.tenant.impersonation: |         if not request.tenant.impersonation: | ||||||
|  | |||||||
| @ -14,10 +14,6 @@ class AuthentikCoreConfig(ManagedAppConfig): | |||||||
|     mountpoint = "" |     mountpoint = "" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_core_signals(self): |  | ||||||
|         """Load core signals""" |  | ||||||
|         self.import_module("authentik.core.signals") |  | ||||||
|  |  | ||||||
|     def reconcile_global_debug_worker_hook(self): |     def reconcile_global_debug_worker_hook(self): | ||||||
|         """Dispatch startup tasks inline when debugging""" |         """Dispatch startup tasks inline when debugging""" | ||||||
|         if settings.DEBUG: |         if settings.DEBUG: | ||||||
|  | |||||||
| @ -43,7 +43,9 @@ class TokenBackend(InbuiltBackend): | |||||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any |         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any | ||||||
|     ) -> Optional[User]: |     ) -> Optional[User]: | ||||||
|         try: |         try: | ||||||
|  |             # pylint: disable=no-member | ||||||
|             user = User._default_manager.get_by_natural_key(username) |             user = User._default_manager.get_by_natural_key(username) | ||||||
|  |         # pylint: disable=no-member | ||||||
|         except User.DoesNotExist: |         except User.DoesNotExist: | ||||||
|             # Run the default password hasher once to reduce the timing |             # Run the default password hasher once to reduce the timing | ||||||
|             # difference between an existing and a nonexistent user (#20760). |             # difference between an existing and a nonexistent user (#20760). | ||||||
|  | |||||||
| @ -37,6 +37,7 @@ def clean_expired_models(self: SystemTask): | |||||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") |         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||||
|     # Special case |     # Special case | ||||||
|     amount = 0 |     amount = 0 | ||||||
|  |     # pylint: disable=no-member | ||||||
|     for session in AuthenticatedSession.objects.all(): |     for session in AuthenticatedSession.objects.all(): | ||||||
|         cache_key = f"{KEY_PREFIX}{session.session_key}" |         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||||
|         value = None |         value = None | ||||||
| @ -49,6 +50,7 @@ def clean_expired_models(self: SystemTask): | |||||||
|             session.delete() |             session.delete() | ||||||
|             amount += 1 |             amount += 1 | ||||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) |     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||||
|  |     # pylint: disable=no-member | ||||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") |     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) |     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||||
|  |  | ||||||
|  | |||||||
| @ -7,8 +7,8 @@ from guardian.shortcuts import get_anonymous_user | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.api.tokens import TokenSerializer | from authentik.core.api.tokens import TokenSerializer | ||||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -17,7 +17,7 @@ class TestTokenAPI(APITestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.user = User.objects.create(username="testuser") |         self.user = create_test_user() | ||||||
|         self.admin = create_test_admin_user() |         self.admin = create_test_admin_user() | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|  |  | ||||||
| @ -76,6 +76,24 @@ class TestTokenAPI(APITestCase): | |||||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) |         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||||
|         self.assertEqual(token.expiring, False) |         self.assertEqual(token.expiring, False) | ||||||
|  |  | ||||||
|  |     def test_token_change_user(self): | ||||||
|  |         """Test creating a token and then changing the user""" | ||||||
|  |         ident = generate_id() | ||||||
|  |         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) | ||||||
|  |         self.assertEqual(response.status_code, 201) | ||||||
|  |         token = Token.objects.get(identifier=ident) | ||||||
|  |         self.assertEqual(token.user, self.user) | ||||||
|  |         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||||
|  |         self.assertEqual(token.expiring, True) | ||||||
|  |         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) | ||||||
|  |         response = self.client.put( | ||||||
|  |             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), | ||||||
|  |             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 400) | ||||||
|  |         token.refresh_from_db() | ||||||
|  |         self.assertEqual(token.user, self.user) | ||||||
|  |  | ||||||
|     def test_list(self): |     def test_list(self): | ||||||
|         """Test Token List (Test normal authentication)""" |         """Test Token List (Test normal authentication)""" | ||||||
|         Token.objects.all().delete() |         Token.objects.all().delete() | ||||||
|  | |||||||
| @ -24,13 +24,13 @@ from rest_framework.viewsets import ModelViewSet | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.authorization import SecretKeyFilter | from authentik.api.authorization import SecretKeyFilter | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.crypto.apps import MANAGED_KEY | from authentik.crypto.apps import MANAGED_KEY | ||||||
| from authentik.crypto.builder import CertificateBuilder | from authentik.crypto.builder import CertificateBuilder | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik crypto app config""" | """authentik crypto app config""" | ||||||
|  |  | ||||||
| from datetime import datetime | from datetime import datetime, timezone | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| @ -17,10 +17,6 @@ class AuthentikCryptoConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Crypto" |     verbose_name = "authentik Crypto" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_crypto_tasks(self): |  | ||||||
|         """Load crypto tasks""" |  | ||||||
|         self.import_module("authentik.crypto.tasks") |  | ||||||
|  |  | ||||||
|     def _create_update_cert(self): |     def _create_update_cert(self): | ||||||
|         from authentik.crypto.builder import CertificateBuilder |         from authentik.crypto.builder import CertificateBuilder | ||||||
|         from authentik.crypto.models import CertificateKeyPair |         from authentik.crypto.models import CertificateKeyPair | ||||||
| @ -47,9 +43,9 @@ class AuthentikCryptoConfig(ManagedAppConfig): | |||||||
|         cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( |         cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( | ||||||
|             managed=MANAGED_KEY |             managed=MANAGED_KEY | ||||||
|         ).first() |         ).first() | ||||||
|         now = datetime.now() |         now = datetime.now(tz=timezone.utc) | ||||||
|         if not cert or ( |         if not cert or ( | ||||||
|             now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after |             now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc | ||||||
|         ): |         ): | ||||||
|             self._create_update_cert() |             self._create_update_cert() | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| """Enterprise API Views""" | """Enterprise API Views""" | ||||||
|  |  | ||||||
| from datetime import datetime, timedelta | from dataclasses import asdict | ||||||
|  | from datetime import timedelta | ||||||
|  |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| @ -8,29 +9,29 @@ from drf_spectacular.types import OpenApiTypes | |||||||
| from drf_spectacular.utils import extend_schema, inline_serializer | from drf_spectacular.utils import extend_schema, inline_serializer | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField | from rest_framework.fields import CharField, IntegerField | ||||||
| from rest_framework.permissions import IsAuthenticated | from rest_framework.permissions import IsAuthenticated | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.core.models import User, UserTypes | from authentik.core.models import User, UserTypes | ||||||
| from authentik.enterprise.models import License, LicenseKey | from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer | ||||||
|  | from authentik.enterprise.models import License | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.root.install_id import get_install_id | from authentik.root.install_id import get_install_id | ||||||
|  |  | ||||||
|  |  | ||||||
| class EnterpriseRequiredMixin: | class EnterpriseRequiredMixin: | ||||||
|     """Mixin to validate that a valid enterprise license |     """Mixin to validate that a valid enterprise license | ||||||
|     exists before allowing to safe the object""" |     exists before allowing to save the object""" | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict) -> dict: |     def validate(self, attrs: dict) -> dict: | ||||||
|         """Check that a valid license exists""" |         """Check that a valid license exists""" | ||||||
|         total = LicenseKey.get_total() |         if not LicenseKey.cached_summary().has_license: | ||||||
|         if not total.is_valid(): |  | ||||||
|             raise ValidationError(_("Enterprise is required to create/update this object.")) |             raise ValidationError(_("Enterprise is required to create/update this object.")) | ||||||
|         return super().validate(attrs) |         return super().validate(attrs) | ||||||
|  |  | ||||||
| @ -61,19 +62,6 @@ class LicenseSerializer(ModelSerializer): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
| class LicenseSummary(PassiveSerializer): |  | ||||||
|     """Serializer for license status""" |  | ||||||
|  |  | ||||||
|     internal_users = IntegerField(required=True) |  | ||||||
|     external_users = IntegerField(required=True) |  | ||||||
|     valid = BooleanField() |  | ||||||
|     show_admin_warning = BooleanField() |  | ||||||
|     show_user_warning = BooleanField() |  | ||||||
|     read_only = BooleanField() |  | ||||||
|     latest_valid = DateTimeField() |  | ||||||
|     has_license = BooleanField() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class LicenseForecastSerializer(PassiveSerializer): | class LicenseForecastSerializer(PassiveSerializer): | ||||||
|     """Serializer for license forecast""" |     """Serializer for license forecast""" | ||||||
|  |  | ||||||
| @ -111,31 +99,13 @@ class LicenseViewSet(UsedByMixin, ModelViewSet): | |||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         request=OpenApiTypes.NONE, |         request=OpenApiTypes.NONE, | ||||||
|         responses={ |         responses={ | ||||||
|             200: LicenseSummary(), |             200: LicenseSummarySerializer(), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) |     @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) | ||||||
|     def summary(self, request: Request) -> Response: |     def summary(self, request: Request) -> Response: | ||||||
|         """Get the total license status""" |         """Get the total license status""" | ||||||
|         total = LicenseKey.get_total() |         response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary())) | ||||||
|         last_valid = LicenseKey.last_valid_date() |  | ||||||
|         # TODO: move this to a different place? |  | ||||||
|         show_admin_warning = last_valid < now() - timedelta(weeks=2) |  | ||||||
|         show_user_warning = last_valid < now() - timedelta(weeks=4) |  | ||||||
|         read_only = last_valid < now() - timedelta(weeks=6) |  | ||||||
|         latest_valid = datetime.fromtimestamp(total.exp) |  | ||||||
|         response = LicenseSummary( |  | ||||||
|             data={ |  | ||||||
|                 "internal_users": total.internal_users, |  | ||||||
|                 "external_users": total.external_users, |  | ||||||
|                 "valid": total.is_valid(), |  | ||||||
|                 "show_admin_warning": show_admin_warning, |  | ||||||
|                 "show_user_warning": show_user_warning, |  | ||||||
|                 "read_only": read_only, |  | ||||||
|                 "latest_valid": latest_valid, |  | ||||||
|                 "has_license": License.objects.all().count() > 0, |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|         response.is_valid(raise_exception=True) |         response.is_valid(raise_exception=True) | ||||||
|         return Response(response.data) |         return Response(response.data) | ||||||
|  |  | ||||||
|  | |||||||
| @ -17,16 +17,12 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | |||||||
|     verbose_name = "authentik Enterprise" |     verbose_name = "authentik Enterprise" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_enterprise_signals(self): |  | ||||||
|         """Load enterprise signals""" |  | ||||||
|         self.import_module("authentik.enterprise.signals") |  | ||||||
|  |  | ||||||
|     def enabled(self): |     def enabled(self): | ||||||
|         """Return true if enterprise is enabled and valid""" |         """Return true if enterprise is enabled and valid""" | ||||||
|         return self.check_enabled() or settings.TEST |         return self.check_enabled() or settings.TEST | ||||||
|  |  | ||||||
|     def check_enabled(self): |     def check_enabled(self): | ||||||
|         """Actual enterprise check, cached""" |         """Actual enterprise check, cached""" | ||||||
|         from authentik.enterprise.models import LicenseKey |         from authentik.enterprise.license import LicenseKey | ||||||
|  |  | ||||||
|         return LicenseKey.get_total().is_valid() |         return LicenseKey.cached_summary().valid | ||||||
|  | |||||||
| @ -11,7 +11,6 @@ from django.db.models.expressions import BaseExpression, Combinable | |||||||
| from django.db.models.signals import post_init | from django.db.models.signals import post_init | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
|  |  | ||||||
| from authentik.core.models import User |  | ||||||
| from authentik.events.middleware import AuditMiddleware, should_log_model | from authentik.events.middleware import AuditMiddleware, should_log_model | ||||||
| from authentik.events.utils import cleanse_dict, sanitize_item | from authentik.events.utils import cleanse_dict, sanitize_item | ||||||
|  |  | ||||||
| @ -19,26 +18,19 @@ from authentik.events.utils import cleanse_dict, sanitize_item | |||||||
| class EnterpriseAuditMiddleware(AuditMiddleware): | class EnterpriseAuditMiddleware(AuditMiddleware): | ||||||
|     """Enterprise audit middleware""" |     """Enterprise audit middleware""" | ||||||
|  |  | ||||||
|     _enabled = None |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def enabled(self): |     def enabled(self): | ||||||
|         """Lazy check if audit logging is enabled""" |         """Check if audit logging is enabled""" | ||||||
|         if self._enabled is None: |         return apps.get_app_config("authentik_enterprise").enabled() | ||||||
|             self._enabled = apps.get_app_config("authentik_enterprise").enabled() |  | ||||||
|         return self._enabled |  | ||||||
|  |  | ||||||
|     def connect(self, request: HttpRequest): |     def connect(self, request: HttpRequest): | ||||||
|         super().connect(request) |         super().connect(request) | ||||||
|         if not self.enabled: |         if not self.enabled: | ||||||
|             return |             return | ||||||
|         user = getattr(request, "user", self.anonymous_user) |  | ||||||
|         if not user.is_authenticated: |  | ||||||
|             user = self.anonymous_user |  | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             return |             return | ||||||
|         post_init.connect( |         post_init.connect( | ||||||
|             partial(self.post_init_handler, user=user, request=request), |             partial(self.post_init_handler, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
| @ -80,7 +72,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} |                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||||
|         return sanitize_item(diff) |         return sanitize_item(diff) | ||||||
|  |  | ||||||
|     def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): |     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         """post_init django model handler""" |         """post_init django model handler""" | ||||||
|         if not should_log_model(instance): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
| @ -95,7 +87,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|         user: User, |  | ||||||
|         request: HttpRequest, |         request: HttpRequest, | ||||||
|         sender, |         sender, | ||||||
|         instance: Model, |         instance: Model, | ||||||
| @ -117,6 +108,4 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|                 for field_set in ignored_field_sets: |                 for field_set in ignored_field_sets: | ||||||
|                     if set(diff.keys()) == set(field_set): |                     if set(diff.keys()) == set(field_set): | ||||||
|                         return None |                         return None | ||||||
|         return super().post_save_handler( |         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) | ||||||
|             user, request, sender, instance, created, thread_kwargs, **_ |  | ||||||
|         ) |  | ||||||
|  | |||||||
							
								
								
									
										214
									
								
								authentik/enterprise/license.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								authentik/enterprise/license.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,214 @@ | |||||||
|  | """Enterprise license""" | ||||||
|  |  | ||||||
|  | from base64 import b64decode | ||||||
|  | from binascii import Error | ||||||
|  | from dataclasses import asdict, dataclass, field | ||||||
|  | from datetime import datetime, timedelta | ||||||
|  | from enum import Enum | ||||||
|  | from functools import lru_cache | ||||||
|  | from time import mktime | ||||||
|  |  | ||||||
|  | from cryptography.exceptions import InvalidSignature | ||||||
|  | from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate | ||||||
|  | from dacite import from_dict | ||||||
|  | from django.core.cache import cache | ||||||
|  | from django.db.models.query import QuerySet | ||||||
|  | from django.utils.timezone import now | ||||||
|  | from jwt import PyJWTError, decode, get_unverified_header | ||||||
|  | from rest_framework.exceptions import ValidationError | ||||||
|  | from rest_framework.fields import BooleanField, DateTimeField, IntegerField | ||||||
|  |  | ||||||
|  | from authentik.core.api.utils import PassiveSerializer | ||||||
|  | from authentik.core.models import User, UserTypes | ||||||
|  | from authentik.enterprise.models import License, LicenseUsage | ||||||
|  | from authentik.root.install_id import get_install_id | ||||||
|  |  | ||||||
|  | CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" | ||||||
|  | CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 2 Hours | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @lru_cache() | ||||||
|  | def get_licensing_key() -> Certificate: | ||||||
|  |     """Get Root CA PEM""" | ||||||
|  |     with open("authentik/enterprise/public.pem", "rb") as _key: | ||||||
|  |         return load_pem_x509_certificate(_key.read()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_license_aud() -> str: | ||||||
|  |     """Get the JWT audience field""" | ||||||
|  |     return f"enterprise.goauthentik.io/license/{get_install_id()}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LicenseFlags(Enum): | ||||||
|  |     """License flags""" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class LicenseSummary: | ||||||
|  |     """Internal representation of a license summary""" | ||||||
|  |  | ||||||
|  |     internal_users: int | ||||||
|  |     external_users: int | ||||||
|  |     valid: bool | ||||||
|  |     show_admin_warning: bool | ||||||
|  |     show_user_warning: bool | ||||||
|  |     read_only: bool | ||||||
|  |     latest_valid: datetime | ||||||
|  |     has_license: bool | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class LicenseSummarySerializer(PassiveSerializer): | ||||||
|  |     """Serializer for license status""" | ||||||
|  |  | ||||||
|  |     internal_users = IntegerField(required=True) | ||||||
|  |     external_users = IntegerField(required=True) | ||||||
|  |     valid = BooleanField() | ||||||
|  |     show_admin_warning = BooleanField() | ||||||
|  |     show_user_warning = BooleanField() | ||||||
|  |     read_only = BooleanField() | ||||||
|  |     latest_valid = DateTimeField() | ||||||
|  |     has_license = BooleanField() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class LicenseKey: | ||||||
|  |     """License JWT claims""" | ||||||
|  |  | ||||||
|  |     aud: str | ||||||
|  |     exp: int | ||||||
|  |  | ||||||
|  |     name: str | ||||||
|  |     internal_users: int = 0 | ||||||
|  |     external_users: int = 0 | ||||||
|  |     flags: list[LicenseFlags] = field(default_factory=list) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def validate(jwt: str) -> "LicenseKey": | ||||||
|  |         """Validate the license from a given JWT""" | ||||||
|  |         try: | ||||||
|  |             headers = get_unverified_header(jwt) | ||||||
|  |         except PyJWTError: | ||||||
|  |             raise ValidationError("Unable to verify license") | ||||||
|  |         x5c: list[str] = headers.get("x5c", []) | ||||||
|  |         if len(x5c) < 1: | ||||||
|  |             raise ValidationError("Unable to verify license") | ||||||
|  |         try: | ||||||
|  |             our_cert = load_der_x509_certificate(b64decode(x5c[0])) | ||||||
|  |             intermediate = load_der_x509_certificate(b64decode(x5c[1])) | ||||||
|  |             our_cert.verify_directly_issued_by(intermediate) | ||||||
|  |             intermediate.verify_directly_issued_by(get_licensing_key()) | ||||||
|  |         except (InvalidSignature, TypeError, ValueError, Error): | ||||||
|  |             raise ValidationError("Unable to verify license") | ||||||
|  |         try: | ||||||
|  |             body = from_dict( | ||||||
|  |                 LicenseKey, | ||||||
|  |                 decode( | ||||||
|  |                     jwt, | ||||||
|  |                     our_cert.public_key(), | ||||||
|  |                     algorithms=["ES512"], | ||||||
|  |                     audience=get_license_aud(), | ||||||
|  |                 ), | ||||||
|  |             ) | ||||||
|  |         except PyJWTError: | ||||||
|  |             raise ValidationError("Unable to verify license") | ||||||
|  |         return body | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def get_total() -> "LicenseKey": | ||||||
|  |         """Get a summarized version of all (not expired) licenses""" | ||||||
|  |         active_licenses = License.objects.filter(expiry__gte=now()) | ||||||
|  |         total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) | ||||||
|  |         for lic in active_licenses: | ||||||
|  |             total.internal_users += lic.internal_users | ||||||
|  |             total.external_users += lic.external_users | ||||||
|  |             exp_ts = int(mktime(lic.expiry.timetuple())) | ||||||
|  |             if total.exp == 0: | ||||||
|  |                 total.exp = exp_ts | ||||||
|  |             if exp_ts <= total.exp: | ||||||
|  |                 total.exp = exp_ts | ||||||
|  |             total.flags.extend(lic.status.flags) | ||||||
|  |         return total | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def base_user_qs() -> QuerySet: | ||||||
|  |         """Base query set for all users""" | ||||||
|  |         return User.objects.all().exclude_anonymous().exclude(is_active=False) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def get_default_user_count(): | ||||||
|  |         """Get current default user count""" | ||||||
|  |         return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def get_external_user_count(): | ||||||
|  |         """Get current external user count""" | ||||||
|  |         # Count since start of the month | ||||||
|  |         last_month = now().replace(day=1) | ||||||
|  |         return ( | ||||||
|  |             LicenseKey.base_user_qs() | ||||||
|  |             .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) | ||||||
|  |             .count() | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def is_valid(self) -> bool: | ||||||
|  |         """Check if the given license body covers all users | ||||||
|  |  | ||||||
|  |         Only checks the current count, no historical data is checked""" | ||||||
|  |         default_users = self.get_default_user_count() | ||||||
|  |         if default_users > self.internal_users: | ||||||
|  |             return False | ||||||
|  |         active_users = self.get_external_user_count() | ||||||
|  |         if active_users > self.external_users: | ||||||
|  |             return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def record_usage(self): | ||||||
|  |         """Capture the current validity status and metrics and save them""" | ||||||
|  |         threshold = now() - timedelta(hours=8) | ||||||
|  |         if not LicenseUsage.objects.filter(record_date__gte=threshold).exists(): | ||||||
|  |             LicenseUsage.objects.create( | ||||||
|  |                 user_count=self.get_default_user_count(), | ||||||
|  |                 external_user_count=self.get_external_user_count(), | ||||||
|  |                 within_limits=self.is_valid(), | ||||||
|  |             ) | ||||||
|  |         summary = asdict(self.summary()) | ||||||
|  |         # Also cache the latest summary for the middleware | ||||||
|  |         cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) | ||||||
|  |         return summary | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def last_valid_date() -> datetime: | ||||||
|  |         """Get the last date the license was valid""" | ||||||
|  |         usage: LicenseUsage = ( | ||||||
|  |             LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() | ||||||
|  |         ) | ||||||
|  |         if not usage: | ||||||
|  |             return now() | ||||||
|  |         return usage.record_date | ||||||
|  |  | ||||||
|  |     def summary(self) -> LicenseSummary: | ||||||
|  |         """Summary of license status""" | ||||||
|  |         has_license = License.objects.all().count() > 0 | ||||||
|  |         last_valid = LicenseKey.last_valid_date() | ||||||
|  |         show_admin_warning = last_valid < now() - timedelta(weeks=2) | ||||||
|  |         show_user_warning = last_valid < now() - timedelta(weeks=4) | ||||||
|  |         read_only = last_valid < now() - timedelta(weeks=6) | ||||||
|  |         latest_valid = datetime.fromtimestamp(self.exp) | ||||||
|  |         return LicenseSummary( | ||||||
|  |             show_admin_warning=show_admin_warning and has_license, | ||||||
|  |             show_user_warning=show_user_warning and has_license, | ||||||
|  |             read_only=read_only and has_license, | ||||||
|  |             latest_valid=latest_valid, | ||||||
|  |             internal_users=self.internal_users, | ||||||
|  |             external_users=self.external_users, | ||||||
|  |             valid=self.is_valid(), | ||||||
|  |             has_license=has_license, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def cached_summary() -> LicenseSummary: | ||||||
|  |         """Helper method which looks up the last summary""" | ||||||
|  |         summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|  |         if not summary: | ||||||
|  |             return LicenseKey.get_total().summary() | ||||||
|  |         return from_dict(LicenseSummary, summary) | ||||||
							
								
								
									
										64
									
								
								authentik/enterprise/middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								authentik/enterprise/middleware.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,64 @@ | |||||||
|  | """Enterprise middleware""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
|  |  | ||||||
|  | from django.http import HttpRequest, HttpResponse, JsonResponse | ||||||
|  | from django.urls import resolve | ||||||
|  | from structlog.stdlib import BoundLogger, get_logger | ||||||
|  |  | ||||||
|  | from authentik.enterprise.api import LicenseViewSet | ||||||
|  | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.flows.views.executor import FlowExecutorView | ||||||
|  | from authentik.lib.utils.reflection import class_to_path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class EnterpriseMiddleware: | ||||||
|  |     """Enterprise middleware""" | ||||||
|  |  | ||||||
|  |     get_response: Callable[[HttpRequest], HttpResponse] | ||||||
|  |     logger: BoundLogger | ||||||
|  |  | ||||||
|  |     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||||
|  |         self.get_response = get_response | ||||||
|  |         self.logger = get_logger().bind() | ||||||
|  |  | ||||||
|  |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|  |         resolver_match = resolve(request.path_info) | ||||||
|  |         request.resolver_match = resolver_match | ||||||
|  |         if not self.is_request_allowed(request): | ||||||
|  |             self.logger.warning("Refusing request due to expired/invalid license") | ||||||
|  |             return JsonResponse( | ||||||
|  |                 { | ||||||
|  |                     "detail": "Request denied due to expired/invalid license.", | ||||||
|  |                     "code": "denied_license", | ||||||
|  |                 }, | ||||||
|  |                 status=400, | ||||||
|  |             ) | ||||||
|  |         return self.get_response(request) | ||||||
|  |  | ||||||
|  |     def is_request_allowed(self, request: HttpRequest) -> bool: | ||||||
|  |         """Check if a specific request is allowed""" | ||||||
|  |         if self.is_request_always_allowed(request): | ||||||
|  |             return True | ||||||
|  |         cached_status = LicenseKey.cached_summary() | ||||||
|  |         if not cached_status: | ||||||
|  |             return True | ||||||
|  |         if cached_status.read_only: | ||||||
|  |             return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def is_request_always_allowed(self, request: HttpRequest): | ||||||
|  |         """Check if a request is always allowed""" | ||||||
|  |         # Always allow "safe" methods | ||||||
|  |         if request.method.lower() in ["get", "head", "options", "trace"]: | ||||||
|  |             return True | ||||||
|  |         # Always allow requests to manage licenses | ||||||
|  |         if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet): | ||||||
|  |             return True | ||||||
|  |         # Flow executor is mounted as an API path but explicitly allowed | ||||||
|  |         if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView): | ||||||
|  |             return True | ||||||
|  |         # Only apply these restrictions to the API | ||||||
|  |         if "authentik_api" not in request.resolver_match.app_names: | ||||||
|  |             return True | ||||||
|  |         return False | ||||||
| @ -1,159 +1,20 @@ | |||||||
| """Enterprise models""" | """Enterprise models""" | ||||||
|  |  | ||||||
| from base64 import b64decode | from datetime import timedelta | ||||||
| from binascii import Error | from typing import TYPE_CHECKING | ||||||
| from dataclasses import dataclass, field |  | ||||||
| from datetime import datetime, timedelta |  | ||||||
| from enum import Enum |  | ||||||
| from functools import lru_cache |  | ||||||
| from time import mktime |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from cryptography.exceptions import InvalidSignature |  | ||||||
| from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate |  | ||||||
| from dacite import from_dict |  | ||||||
| from django.contrib.postgres.indexes import HashIndex | from django.contrib.postgres.indexes import HashIndex | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.db.models.query import QuerySet |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from jwt import PyJWTError, decode, get_unverified_header |  | ||||||
| from rest_framework.exceptions import ValidationError |  | ||||||
| from rest_framework.serializers import BaseSerializer | from rest_framework.serializers import BaseSerializer | ||||||
|  |  | ||||||
| from authentik.core.models import ExpiringModel, User, UserTypes | from authentik.core.models import ExpiringModel | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.root.install_id import get_install_id |  | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
| @lru_cache() |     from authentik.enterprise.license import LicenseKey | ||||||
| def get_licensing_key() -> Certificate: |  | ||||||
|     """Get Root CA PEM""" |  | ||||||
|     with open("authentik/enterprise/public.pem", "rb") as _key: |  | ||||||
|         return load_pem_x509_certificate(_key.read()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_license_aud() -> str: |  | ||||||
|     """Get the JWT audience field""" |  | ||||||
|     return f"enterprise.goauthentik.io/license/{get_install_id()}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class LicenseFlags(Enum): |  | ||||||
|     """License flags""" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass |  | ||||||
| class LicenseKey: |  | ||||||
|     """License JWT claims""" |  | ||||||
|  |  | ||||||
|     aud: str |  | ||||||
|     exp: int |  | ||||||
|  |  | ||||||
|     name: str |  | ||||||
|     internal_users: int = 0 |  | ||||||
|     external_users: int = 0 |  | ||||||
|     flags: list[LicenseFlags] = field(default_factory=list) |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def validate(jwt: str) -> "LicenseKey": |  | ||||||
|         """Validate the license from a given JWT""" |  | ||||||
|         try: |  | ||||||
|             headers = get_unverified_header(jwt) |  | ||||||
|         except PyJWTError: |  | ||||||
|             raise ValidationError("Unable to verify license") |  | ||||||
|         x5c: list[str] = headers.get("x5c", []) |  | ||||||
|         if len(x5c) < 1: |  | ||||||
|             raise ValidationError("Unable to verify license") |  | ||||||
|         try: |  | ||||||
|             our_cert = load_der_x509_certificate(b64decode(x5c[0])) |  | ||||||
|             intermediate = load_der_x509_certificate(b64decode(x5c[1])) |  | ||||||
|             our_cert.verify_directly_issued_by(intermediate) |  | ||||||
|             intermediate.verify_directly_issued_by(get_licensing_key()) |  | ||||||
|         except (InvalidSignature, TypeError, ValueError, Error): |  | ||||||
|             raise ValidationError("Unable to verify license") |  | ||||||
|         try: |  | ||||||
|             body = from_dict( |  | ||||||
|                 LicenseKey, |  | ||||||
|                 decode( |  | ||||||
|                     jwt, |  | ||||||
|                     our_cert.public_key(), |  | ||||||
|                     algorithms=["ES512"], |  | ||||||
|                     audience=get_license_aud(), |  | ||||||
|                 ), |  | ||||||
|             ) |  | ||||||
|         except PyJWTError: |  | ||||||
|             raise ValidationError("Unable to verify license") |  | ||||||
|         return body |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def get_total() -> "LicenseKey": |  | ||||||
|         """Get a summarized version of all (not expired) licenses""" |  | ||||||
|         active_licenses = License.objects.filter(expiry__gte=now()) |  | ||||||
|         total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) |  | ||||||
|         for lic in active_licenses: |  | ||||||
|             total.internal_users += lic.internal_users |  | ||||||
|             total.external_users += lic.external_users |  | ||||||
|             exp_ts = int(mktime(lic.expiry.timetuple())) |  | ||||||
|             if total.exp == 0: |  | ||||||
|                 total.exp = exp_ts |  | ||||||
|             if exp_ts <= total.exp: |  | ||||||
|                 total.exp = exp_ts |  | ||||||
|             total.flags.extend(lic.status.flags) |  | ||||||
|         return total |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def base_user_qs() -> QuerySet: |  | ||||||
|         """Base query set for all users""" |  | ||||||
|         return User.objects.all().exclude_anonymous().exclude(is_active=False) |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def get_default_user_count(): |  | ||||||
|         """Get current default user count""" |  | ||||||
|         return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def get_external_user_count(): |  | ||||||
|         """Get current external user count""" |  | ||||||
|         # Count since start of the month |  | ||||||
|         last_month = now().replace(day=1) |  | ||||||
|         return ( |  | ||||||
|             LicenseKey.base_user_qs() |  | ||||||
|             .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) |  | ||||||
|             .count() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def is_valid(self) -> bool: |  | ||||||
|         """Check if the given license body covers all users |  | ||||||
|  |  | ||||||
|         Only checks the current count, no historical data is checked""" |  | ||||||
|         default_users = self.get_default_user_count() |  | ||||||
|         if default_users > self.internal_users: |  | ||||||
|             return False |  | ||||||
|         active_users = self.get_external_user_count() |  | ||||||
|         if active_users > self.external_users: |  | ||||||
|             return False |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     def record_usage(self): |  | ||||||
|         """Capture the current validity status and metrics and save them""" |  | ||||||
|         threshold = now() - timedelta(hours=8) |  | ||||||
|         if LicenseUsage.objects.filter(record_date__gte=threshold).exists(): |  | ||||||
|             return |  | ||||||
|         LicenseUsage.objects.create( |  | ||||||
|             user_count=self.get_default_user_count(), |  | ||||||
|             external_user_count=self.get_external_user_count(), |  | ||||||
|             within_limits=self.is_valid(), |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def last_valid_date() -> datetime: |  | ||||||
|         """Get the last date the license was valid""" |  | ||||||
|         usage: LicenseUsage = ( |  | ||||||
|             LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() |  | ||||||
|         ) |  | ||||||
|         if not usage: |  | ||||||
|             return now() |  | ||||||
|         return usage.record_date |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class License(SerializerModel): | class License(SerializerModel): | ||||||
| @ -174,8 +35,10 @@ class License(SerializerModel): | |||||||
|         return LicenseSerializer |         return LicenseSerializer | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def status(self) -> LicenseKey: |     def status(self) -> "LicenseKey": | ||||||
|         """Get parsed license status""" |         """Get parsed license status""" | ||||||
|  |         from authentik.enterprise.license import LicenseKey | ||||||
|  |  | ||||||
|         return LicenseKey.validate(self.key) |         return LicenseKey.validate(self.key) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from typing import Optional | |||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from authentik.core.models import User, UserTypes | from authentik.core.models import User, UserTypes | ||||||
| from authentik.enterprise.models import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
| from authentik.policies.types import PolicyRequest, PolicyResult | from authentik.policies.types import PolicyRequest, PolicyResult | ||||||
| from authentik.policies.views import PolicyAccessView | from authentik.policies.views import PolicyAccessView | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										53
									
								
								authentik/enterprise/providers/rac/api/connection_tokens.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								authentik/enterprise/providers/rac/api/connection_tokens.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | |||||||
|  | """RAC Provider API Views""" | ||||||
|  |  | ||||||
|  | from django_filters.rest_framework.backends import DjangoFilterBackend | ||||||
|  | from rest_framework import mixins | ||||||
|  | from rest_framework.filters import OrderingFilter, SearchFilter | ||||||
|  | from rest_framework.serializers import ModelSerializer | ||||||
|  | from rest_framework.viewsets import GenericViewSet | ||||||
|  |  | ||||||
|  | from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||||
|  | from authentik.core.api.groups import GroupMemberSerializer | ||||||
|  | from authentik.core.api.used_by import UsedByMixin | ||||||
|  | from authentik.enterprise.api import EnterpriseRequiredMixin | ||||||
|  | from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer | ||||||
|  | from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||||
|  | from authentik.enterprise.providers.rac.models import ConnectionToken | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||||
|  |     """ConnectionToken Serializer""" | ||||||
|  |  | ||||||
|  |     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||||
|  |     endpoint_obj = EndpointSerializer(source="endpoint", read_only=True) | ||||||
|  |     user = GroupMemberSerializer(source="session.user", read_only=True) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         model = ConnectionToken | ||||||
|  |         fields = [ | ||||||
|  |             "pk", | ||||||
|  |             "provider", | ||||||
|  |             "provider_obj", | ||||||
|  |             "endpoint", | ||||||
|  |             "endpoint_obj", | ||||||
|  |             "user", | ||||||
|  |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConnectionTokenViewSet( | ||||||
|  |     mixins.RetrieveModelMixin, | ||||||
|  |     mixins.UpdateModelMixin, | ||||||
|  |     mixins.DestroyModelMixin, | ||||||
|  |     UsedByMixin, | ||||||
|  |     mixins.ListModelMixin, | ||||||
|  |     GenericViewSet, | ||||||
|  | ): | ||||||
|  |     """ConnectionToken Viewset""" | ||||||
|  |  | ||||||
|  |     queryset = ConnectionToken.objects.all().select_related("session", "endpoint") | ||||||
|  |     serializer_class = ConnectionTokenSerializer | ||||||
|  |     filterset_fields = ["endpoint", "session__user", "provider"] | ||||||
|  |     search_fields = ["endpoint__name", "provider__name"] | ||||||
|  |     ordering = ["endpoint__name", "provider__name"] | ||||||
|  |     permission_classes = [OwnerSuperuserPermissions] | ||||||
|  |     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||||
| @ -16,7 +16,12 @@ class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer): | |||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = RACProvider |         model = RACProvider | ||||||
|         fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"] |         fields = ProviderSerializer.Meta.fields + [ | ||||||
|  |             "settings", | ||||||
|  |             "outpost_set", | ||||||
|  |             "connection_expiry", | ||||||
|  |             "delete_token_on_disconnect", | ||||||
|  |         ] | ||||||
|         extra_kwargs = ProviderSerializer.Meta.extra_kwargs |         extra_kwargs = ProviderSerializer.Meta.extra_kwargs | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -12,7 +12,3 @@ class AuthentikEnterpriseProviderRAC(EnterpriseConfig): | |||||||
|     default = True |     default = True | ||||||
|     mountpoint = "" |     mountpoint = "" | ||||||
|     ws_mountpoint = "authentik.enterprise.providers.rac.urls" |     ws_mountpoint = "authentik.enterprise.providers.rac.urls" | ||||||
|  |  | ||||||
|     def reconcile_global_load_rac_signals(self): |  | ||||||
|         """Load rac signals""" |  | ||||||
|         self.import_module("authentik.enterprise.providers.rac.signals") |  | ||||||
|  | |||||||
| @ -43,6 +43,7 @@ class RACClientConsumer(AsyncWebsocketConsumer): | |||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
|     async def connect(self): |     async def connect(self): | ||||||
|  |         self.logger = get_logger() | ||||||
|         await self.accept("guacamole") |         await self.accept("guacamole") | ||||||
|         await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name) |         await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name) | ||||||
|         await self.channel_layer.group_add( |         await self.channel_layer.group_add( | ||||||
| @ -64,9 +65,11 @@ class RACClientConsumer(AsyncWebsocketConsumer): | |||||||
|     @database_sync_to_async |     @database_sync_to_async | ||||||
|     def init_outpost_connection(self): |     def init_outpost_connection(self): | ||||||
|         """Initialize guac connection settings""" |         """Initialize guac connection settings""" | ||||||
|         self.token = ConnectionToken.filter_not_expired( |         self.token = ( | ||||||
|             token=self.scope["url_route"]["kwargs"]["token"] |             ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"]) | ||||||
|         ).first() |             .select_related("endpoint", "provider", "session", "session__user") | ||||||
|  |             .first() | ||||||
|  |         ) | ||||||
|         if not self.token: |         if not self.token: | ||||||
|             raise DenyConnection() |             raise DenyConnection() | ||||||
|         self.provider = self.token.provider |         self.provider = self.token.provider | ||||||
| @ -107,6 +110,9 @@ class RACClientConsumer(AsyncWebsocketConsumer): | |||||||
|                 OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid}, |                 OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid}, | ||||||
|                 msg, |                 msg, | ||||||
|             ) |             ) | ||||||
|  |         if self.provider and self.provider.delete_token_on_disconnect: | ||||||
|  |             self.logger.info("Deleting connection token to prevent reconnect", token=self.token) | ||||||
|  |             self.token.delete() | ||||||
|  |  | ||||||
|     async def receive(self, text_data=None, bytes_data=None): |     async def receive(self, text_data=None, bytes_data=None): | ||||||
|         """Mirror data received from client to the dest_channel_id |         """Mirror data received from client to the dest_channel_id | ||||||
|  | |||||||
| @ -0,0 +1,181 @@ | |||||||
|  | # Generated by Django 5.0.1 on 2024-02-11 19:04 | ||||||
|  |  | ||||||
|  | import uuid | ||||||
|  |  | ||||||
|  | import django.db.models.deletion | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  | import authentik.core.models | ||||||
|  | import authentik.lib.utils.time | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     replaces = [ | ||||||
|  |         ("authentik_providers_rac", "0001_initial"), | ||||||
|  |         ("authentik_providers_rac", "0002_endpoint_maximum_connections"), | ||||||
|  |         ("authentik_providers_rac", "0003_alter_connectiontoken_options_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     initial = True | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_core", "0032_group_roles"), | ||||||
|  |         ("authentik_policies", "0011_policybinding_failure_result_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="RACPropertyMapping", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "propertymapping_ptr", | ||||||
|  |                     models.OneToOneField( | ||||||
|  |                         auto_created=True, | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         parent_link=True, | ||||||
|  |                         primary_key=True, | ||||||
|  |                         serialize=False, | ||||||
|  |                         to="authentik_core.propertymapping", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("static_settings", models.JSONField(default=dict)), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "RAC Property Mapping", | ||||||
|  |                 "verbose_name_plural": "RAC Property Mappings", | ||||||
|  |             }, | ||||||
|  |             bases=("authentik_core.propertymapping",), | ||||||
|  |         ), | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="RACProvider", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "provider_ptr", | ||||||
|  |                     models.OneToOneField( | ||||||
|  |                         auto_created=True, | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         parent_link=True, | ||||||
|  |                         primary_key=True, | ||||||
|  |                         serialize=False, | ||||||
|  |                         to="authentik_core.provider", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("settings", models.JSONField(default=dict)), | ||||||
|  |                 ( | ||||||
|  |                     "auth_mode", | ||||||
|  |                     models.TextField( | ||||||
|  |                         choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt" | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "connection_expiry", | ||||||
|  |                     models.TextField( | ||||||
|  |                         default="hours=8", | ||||||
|  |                         help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)", | ||||||
|  |                         validators=[authentik.lib.utils.time.timedelta_string_validator], | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "delete_token_on_disconnect", | ||||||
|  |                     models.BooleanField( | ||||||
|  |                         default=False, | ||||||
|  |                         help_text="When set to true, connection tokens will be deleted upon disconnect.", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "RAC Provider", | ||||||
|  |                 "verbose_name_plural": "RAC Providers", | ||||||
|  |             }, | ||||||
|  |             bases=("authentik_core.provider",), | ||||||
|  |         ), | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="Endpoint", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "policybindingmodel_ptr", | ||||||
|  |                     models.OneToOneField( | ||||||
|  |                         auto_created=True, | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         parent_link=True, | ||||||
|  |                         primary_key=True, | ||||||
|  |                         serialize=False, | ||||||
|  |                         to="authentik_policies.policybindingmodel", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("name", models.TextField()), | ||||||
|  |                 ("host", models.TextField()), | ||||||
|  |                 ( | ||||||
|  |                     "protocol", | ||||||
|  |                     models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]), | ||||||
|  |                 ), | ||||||
|  |                 ("settings", models.JSONField(default=dict)), | ||||||
|  |                 ( | ||||||
|  |                     "auth_mode", | ||||||
|  |                     models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "property_mappings", | ||||||
|  |                     models.ManyToManyField( | ||||||
|  |                         blank=True, default=None, to="authentik_core.propertymapping" | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "provider", | ||||||
|  |                     models.ForeignKey( | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         to="authentik_providers_rac.racprovider", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("maximum_connections", models.IntegerField(default=1)), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "RAC Endpoint", | ||||||
|  |                 "verbose_name_plural": "RAC Endpoints", | ||||||
|  |             }, | ||||||
|  |             bases=("authentik_policies.policybindingmodel", models.Model), | ||||||
|  |         ), | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="ConnectionToken", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "expires", | ||||||
|  |                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||||
|  |                 ), | ||||||
|  |                 ("expiring", models.BooleanField(default=True)), | ||||||
|  |                 ( | ||||||
|  |                     "connection_token_uuid", | ||||||
|  |                     models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), | ||||||
|  |                 ), | ||||||
|  |                 ("token", models.TextField(default=authentik.core.models.default_token_key)), | ||||||
|  |                 ("settings", models.JSONField(default=dict)), | ||||||
|  |                 ( | ||||||
|  |                     "endpoint", | ||||||
|  |                     models.ForeignKey( | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         to="authentik_providers_rac.endpoint", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "provider", | ||||||
|  |                     models.ForeignKey( | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         to="authentik_providers_rac.racprovider", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ( | ||||||
|  |                     "session", | ||||||
|  |                     models.ForeignKey( | ||||||
|  |                         on_delete=django.db.models.deletion.CASCADE, | ||||||
|  |                         to="authentik_core.authenticatedsession", | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "abstract": False, | ||||||
|  |                 "verbose_name": "RAC Connection token", | ||||||
|  |                 "verbose_name_plural": "RAC Connection tokens", | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -0,0 +1,28 @@ | |||||||
|  | # Generated by Django 5.0.1 on 2024-02-11 19:04 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_providers_rac", "0002_endpoint_maximum_connections"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AlterModelOptions( | ||||||
|  |             name="connectiontoken", | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "RAC Connection token", | ||||||
|  |                 "verbose_name_plural": "RAC Connection tokens", | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="racprovider", | ||||||
|  |             name="delete_token_on_disconnect", | ||||||
|  |             field=models.BooleanField( | ||||||
|  |                 default=False, | ||||||
|  |                 help_text="When set to true, connection tokens will be deleted upon disconnect.", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -1,17 +1,18 @@ | |||||||
| """RAC Models""" | """RAC Models""" | ||||||
|  |  | ||||||
| from typing import Optional | from typing import Any, Optional | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from deepmerge import always_merger | from deepmerge import always_merger | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
|  | from django.http import HttpRequest | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from rest_framework.serializers import Serializer | from rest_framework.serializers import Serializer | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.exceptions import PropertyMappingExpressionException | from authentik.core.exceptions import PropertyMappingExpressionException | ||||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key | from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User, default_token_key | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.lib.utils.time import timedelta_string_validator | from authentik.lib.utils.time import timedelta_string_validator | ||||||
| @ -51,6 +52,10 @@ class RACProvider(Provider): | |||||||
|             "(Format: hours=-1;minutes=-2;seconds=-3)" |             "(Format: hours=-1;minutes=-2;seconds=-3)" | ||||||
|         ), |         ), | ||||||
|     ) |     ) | ||||||
|  |     delete_token_on_disconnect = models.BooleanField( | ||||||
|  |         default=False, | ||||||
|  |         help_text=_("When set to true, connection tokens will be deleted upon disconnect."), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def launch_url(self) -> Optional[str]: |     def launch_url(self) -> Optional[str]: | ||||||
| @ -107,6 +112,12 @@ class RACPropertyMapping(PropertyMapping): | |||||||
|  |  | ||||||
|     static_settings = models.JSONField(default=dict) |     static_settings = models.JSONField(default=dict) | ||||||
|  |  | ||||||
|  |     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||||
|  |         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||||
|  |         if len(self.static_settings) > 0: | ||||||
|  |             return self.static_settings | ||||||
|  |         return super().evaluate(user, request, **kwargs) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def component(self) -> str: |     def component(self) -> str: | ||||||
|         return "ak-property-mapping-rac-form" |         return "ak-property-mapping-rac-form" | ||||||
| @ -155,9 +166,6 @@ class ConnectionToken(ExpiringModel): | |||||||
|         def mapping_evaluator(mappings: QuerySet): |         def mapping_evaluator(mappings: QuerySet): | ||||||
|             for mapping in mappings: |             for mapping in mappings: | ||||||
|                 mapping: RACPropertyMapping |                 mapping: RACPropertyMapping | ||||||
|                 if len(mapping.static_settings) > 0: |  | ||||||
|                     always_merger.merge(settings, mapping.static_settings) |  | ||||||
|                     continue |  | ||||||
|                 try: |                 try: | ||||||
|                     mapping_settings = mapping.evaluate( |                     mapping_settings = mapping.evaluate( | ||||||
|                         self.session.user, None, endpoint=self.endpoint, provider=self.provider |                         self.session.user, None, endpoint=self.endpoint, provider=self.provider | ||||||
| @ -191,3 +199,13 @@ class ConnectionToken(ExpiringModel): | |||||||
|                 continue |                 continue | ||||||
|             settings[key] = str(value) |             settings[key] = str(value) | ||||||
|         return settings |         return settings | ||||||
|  |  | ||||||
|  |     def __str__(self): | ||||||
|  |         return ( | ||||||
|  |             f"RAC Connection token {self.session.user} to " | ||||||
|  |             f"{self.endpoint.provider.name}/{self.endpoint.name}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     class Meta: | ||||||
|  |         verbose_name = _("RAC Connection token") | ||||||
|  |         verbose_name_plural = _("RAC Connection tokens") | ||||||
|  | |||||||
| @ -45,8 +45,8 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, ** | |||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save, sender=Endpoint) | @receiver(post_save, sender=Endpoint) | ||||||
| def post_save_application(sender: type[Model], instance, created: bool, **_): | def post_save_endpoint(sender: type[Model], instance, created: bool, **_): | ||||||
|     """Clear user's application cache upon application creation""" |     """Clear user's endpoint cache upon endpoint creation""" | ||||||
|     if not created:  # pragma: no cover |     if not created:  # pragma: no cover | ||||||
|         return |         return | ||||||
|  |  | ||||||
|  | |||||||
| @ -70,6 +70,7 @@ class TestEndpointsAPI(APITestCase): | |||||||
|                             "authorization_flow": None, |                             "authorization_flow": None, | ||||||
|                             "property_mappings": [], |                             "property_mappings": [], | ||||||
|                             "connection_expiry": "hours=8", |                             "connection_expiry": "hours=8", | ||||||
|  |                             "delete_token_on_disconnect": False, | ||||||
|                             "component": "ak-provider-rac-form", |                             "component": "ak-provider-rac-form", | ||||||
|                             "assigned_application_slug": self.app.slug, |                             "assigned_application_slug": self.app.slug, | ||||||
|                             "assigned_application_name": self.app.name, |                             "assigned_application_name": self.app.name, | ||||||
| @ -124,6 +125,7 @@ class TestEndpointsAPI(APITestCase): | |||||||
|                             "assigned_application_slug": self.app.slug, |                             "assigned_application_slug": self.app.slug, | ||||||
|                             "assigned_application_name": self.app.name, |                             "assigned_application_name": self.app.name, | ||||||
|                             "connection_expiry": "hours=8", |                             "connection_expiry": "hours=8", | ||||||
|  |                             "delete_token_on_disconnect": False, | ||||||
|                             "verbose_name": "RAC Provider", |                             "verbose_name": "RAC Provider", | ||||||
|                             "verbose_name_plural": "RAC Providers", |                             "verbose_name_plural": "RAC Providers", | ||||||
|                             "meta_model_name": "authentik_providers_rac.racprovider", |                             "meta_model_name": "authentik_providers_rac.racprovider", | ||||||
| @ -152,6 +154,7 @@ class TestEndpointsAPI(APITestCase): | |||||||
|                             "assigned_application_slug": self.app.slug, |                             "assigned_application_slug": self.app.slug, | ||||||
|                             "assigned_application_name": self.app.name, |                             "assigned_application_name": self.app.name, | ||||||
|                             "connection_expiry": "hours=8", |                             "connection_expiry": "hours=8", | ||||||
|  |                             "delete_token_on_disconnect": False, | ||||||
|                             "verbose_name": "RAC Provider", |                             "verbose_name": "RAC Provider", | ||||||
|                             "verbose_name_plural": "RAC Providers", |                             "verbose_name_plural": "RAC Providers", | ||||||
|                             "meta_model_name": "authentik_providers_rac.racprovider", |                             "meta_model_name": "authentik_providers_rac.racprovider", | ||||||
|  | |||||||
| @ -11,7 +11,8 @@ from rest_framework.test import APITestCase | |||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||||
| from authentik.enterprise.models import License, LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.models import License | ||||||
| from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider | from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| @ -39,7 +40,7 @@ class TestRACViews(APITestCase): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.enterprise.models.LicenseKey.validate", |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|         MagicMock( |         MagicMock( | ||||||
|             return_value=LicenseKey( |             return_value=LicenseKey( | ||||||
|                 aud="", |                 aud="", | ||||||
| @ -70,7 +71,7 @@ class TestRACViews(APITestCase): | |||||||
|         self.assertEqual(final_response.status_code, 200) |         self.assertEqual(final_response.status_code, 200) | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.enterprise.models.LicenseKey.validate", |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|         MagicMock( |         MagicMock( | ||||||
|             return_value=LicenseKey( |             return_value=LicenseKey( | ||||||
|                 aud="", |                 aud="", | ||||||
| @ -99,7 +100,7 @@ class TestRACViews(APITestCase): | |||||||
|         self.assertIsInstance(response, AccessDeniedResponse) |         self.assertIsInstance(response, AccessDeniedResponse) | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.enterprise.models.LicenseKey.validate", |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|         MagicMock( |         MagicMock( | ||||||
|             return_value=LicenseKey( |             return_value=LicenseKey( | ||||||
|                 aud="", |                 aud="", | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ from django.urls import path | |||||||
| from django.views.decorators.csrf import ensure_csrf_cookie | from django.views.decorators.csrf import ensure_csrf_cookie | ||||||
|  |  | ||||||
| from authentik.core.channels import TokenOutpostMiddleware | from authentik.core.channels import TokenOutpostMiddleware | ||||||
|  | from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet | ||||||
| from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet | from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet | ||||||
| from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet | from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet | ||||||
| from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet | from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet | ||||||
| @ -45,4 +46,5 @@ api_urlpatterns = [ | |||||||
|     ("providers/rac", RACProviderViewSet), |     ("providers/rac", RACProviderViewSet), | ||||||
|     ("propertymappings/rac", RACPropertyMappingViewSet), |     ("propertymappings/rac", RACPropertyMappingViewSet), | ||||||
|     ("rac/endpoints", EndpointViewSet), |     ("rac/endpoints", EndpointViewSet), | ||||||
|  |     ("rac/connection_tokens", ConnectionTokenViewSet), | ||||||
| ] | ] | ||||||
|  | |||||||
| @ -104,14 +104,15 @@ class RACFinalStage(RedirectStage): | |||||||
|         # Check if we're already at the maximum connection limit |         # Check if we're already at the maximum connection limit | ||||||
|         all_tokens = ConnectionToken.filter_not_expired( |         all_tokens = ConnectionToken.filter_not_expired( | ||||||
|             endpoint=self.endpoint, |             endpoint=self.endpoint, | ||||||
|         ).exclude(endpoint__maximum_connections__lte=-1) |         ) | ||||||
|         if all_tokens.count() >= self.endpoint.maximum_connections: |         if self.endpoint.maximum_connections > -1: | ||||||
|             msg = [_("Maximum connection limit reached.")] |             if all_tokens.count() >= self.endpoint.maximum_connections: | ||||||
|             # Check if any other tokens exist for the current user, and inform them |                 msg = [_("Maximum connection limit reached.")] | ||||||
|             # they are already connected |                 # Check if any other tokens exist for the current user, and inform them | ||||||
|             if all_tokens.filter(session__user=self.request.user).exists(): |                 # they are already connected | ||||||
|                 msg.append(_("(You are already connected in another tab/window)")) |                 if all_tokens.filter(session__user=self.request.user).exists(): | ||||||
|             return self.executor.stage_invalid(" ".join(msg)) |                     msg.append(_("(You are already connected in another tab/window)")) | ||||||
|  |                 return self.executor.stage_invalid(" ".join(msg)) | ||||||
|         return super().dispatch(request, *args, **kwargs) |         return super().dispatch(request, *args, **kwargs) | ||||||
|  |  | ||||||
|     def get_challenge(self, *args, **kwargs) -> RedirectChallenge: |     def get_challenge(self, *args, **kwargs) -> RedirectChallenge: | ||||||
|  | |||||||
| @ -5,9 +5,9 @@ from celery.schedules import crontab | |||||||
| from authentik.lib.utils.time import fqdn_rand | from authentik.lib.utils.time import fqdn_rand | ||||||
|  |  | ||||||
| CELERY_BEAT_SCHEDULE = { | CELERY_BEAT_SCHEDULE = { | ||||||
|     "enterprise_calculate_license": { |     "enterprise_update_usage": { | ||||||
|         "task": "authentik.enterprise.tasks.calculate_license", |         "task": "authentik.enterprise.tasks.enterprise_update_usage", | ||||||
|         "schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/2"), |         "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), | ||||||
|         "options": {"queue": "authentik_scheduled"}, |         "options": {"queue": "authentik_scheduled"}, | ||||||
|     } |     } | ||||||
| } | } | ||||||
| @ -16,3 +16,5 @@ TENANT_APPS = [ | |||||||
|     "authentik.enterprise.audit", |     "authentik.enterprise.audit", | ||||||
|     "authentik.enterprise.providers.rac", |     "authentik.enterprise.providers.rac", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"] | ||||||
|  | |||||||
| @ -2,11 +2,14 @@ | |||||||
|  |  | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
|  |  | ||||||
| from django.db.models.signals import pre_save | from django.core.cache import cache | ||||||
|  | from django.db.models.signals import post_save, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
| from django.utils.timezone import get_current_timezone | from django.utils.timezone import get_current_timezone | ||||||
|  |  | ||||||
|  | from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||||
| from authentik.enterprise.models import License | from authentik.enterprise.models import License | ||||||
|  | from authentik.enterprise.tasks import enterprise_update_usage | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=License) | @receiver(pre_save, sender=License) | ||||||
| @ -17,3 +20,10 @@ def pre_save_license(sender: type[License], instance: License, **_): | |||||||
|     instance.internal_users = status.internal_users |     instance.internal_users = status.internal_users | ||||||
|     instance.external_users = status.external_users |     instance.external_users = status.external_users | ||||||
|     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) |     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(post_save, sender=License) | ||||||
|  | def post_save_license(sender: type[License], instance: License, **_): | ||||||
|  |     """Trigger license usage calculation when license is saved""" | ||||||
|  |     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||||
|  |     enterprise_update_usage.delay() | ||||||
|  | |||||||
| @ -1,10 +1,14 @@ | |||||||
| """Enterprise tasks""" | """Enterprise tasks""" | ||||||
|  |  | ||||||
| from authentik.enterprise.models import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.events.models import TaskStatus | ||||||
|  | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task() | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| def calculate_license(): | @prefill_task | ||||||
|     """Calculate licensing status""" | def enterprise_update_usage(self: SystemTask): | ||||||
|  |     """Update enterprise license status""" | ||||||
|     LicenseKey.get_total().record_usage() |     LicenseKey.get_total().record_usage() | ||||||
|  |     self.set_status(TaskStatus.SUCCESSFUL) | ||||||
|  | |||||||
| @ -8,7 +8,8 @@ from django.test import TestCase | |||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
|  |  | ||||||
| from authentik.enterprise.models import License, LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
|  | from authentik.enterprise.models import License | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
| _exp = int(mktime((now() + timedelta(days=3000)).timetuple())) | _exp = int(mktime((now() + timedelta(days=3000)).timetuple())) | ||||||
| @ -18,7 +19,7 @@ class TestEnterpriseLicense(TestCase): | |||||||
|     """Enterprise license tests""" |     """Enterprise license tests""" | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.enterprise.models.LicenseKey.validate", |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|         MagicMock( |         MagicMock( | ||||||
|             return_value=LicenseKey( |             return_value=LicenseKey( | ||||||
|                 aud="", |                 aud="", | ||||||
| @ -41,7 +42,7 @@ class TestEnterpriseLicense(TestCase): | |||||||
|             License.objects.create(key=generate_id()) |             License.objects.create(key=generate_id()) | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.enterprise.models.LicenseKey.validate", |         "authentik.enterprise.license.LicenseKey.validate", | ||||||
|         MagicMock( |         MagicMock( | ||||||
|             return_value=LicenseKey( |             return_value=LicenseKey( | ||||||
|                 aud="", |                 aud="", | ||||||
|  | |||||||
| @ -12,7 +12,6 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
| @ -24,6 +23,7 @@ from authentik.events.models import ( | |||||||
|     TransportMode, |     TransportMode, | ||||||
| ) | ) | ||||||
| from authentik.events.utils import get_user | from authentik.events.utils import get_user | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class NotificationTransportSerializer(ModelSerializer): | class NotificationTransportSerializer(ModelSerializer): | ||||||
|  | |||||||
| @ -1,6 +1,5 @@ | |||||||
| """Tasks API""" | """Tasks API""" | ||||||
|  |  | ||||||
| from datetime import datetime, timezone |  | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
|  |  | ||||||
| from django.contrib import messages | from django.contrib import messages | ||||||
| @ -8,15 +7,22 @@ from django.utils.translation import gettext_lazy as _ | |||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField | from rest_framework.fields import ( | ||||||
|  |     CharField, | ||||||
|  |     ChoiceField, | ||||||
|  |     DateTimeField, | ||||||
|  |     FloatField, | ||||||
|  |     ListField, | ||||||
|  |     SerializerMethodField, | ||||||
|  | ) | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | from rest_framework.viewsets import ReadOnlyModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.events.models import SystemTask, TaskStatus | from authentik.events.models import SystemTask, TaskStatus | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -28,9 +34,9 @@ class SystemTaskSerializer(ModelSerializer): | |||||||
|     full_name = SerializerMethodField() |     full_name = SerializerMethodField() | ||||||
|     uid = CharField(required=False) |     uid = CharField(required=False) | ||||||
|     description = CharField() |     description = CharField() | ||||||
|     start_timestamp = SerializerMethodField() |     start_timestamp = DateTimeField(read_only=True) | ||||||
|     finish_timestamp = SerializerMethodField() |     finish_timestamp = DateTimeField(read_only=True) | ||||||
|     duration = SerializerMethodField() |     duration = FloatField(read_only=True) | ||||||
|  |  | ||||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) |     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||||
|     messages = ListField(child=CharField()) |     messages = ListField(child=CharField()) | ||||||
| @ -41,18 +47,6 @@ class SystemTaskSerializer(ModelSerializer): | |||||||
|             return f"{instance.name}:{instance.uid}" |             return f"{instance.name}:{instance.uid}" | ||||||
|         return instance.name |         return instance.name | ||||||
|  |  | ||||||
|     def get_start_timestamp(self, instance: SystemTask) -> datetime: |  | ||||||
|         """Timestamp when the task started""" |  | ||||||
|         return datetime.fromtimestamp(instance.start_timestamp, tz=timezone.utc) |  | ||||||
|  |  | ||||||
|     def get_finish_timestamp(self, instance: SystemTask) -> datetime: |  | ||||||
|         """Timestamp when the task finished""" |  | ||||||
|         return datetime.fromtimestamp(instance.finish_timestamp, tz=timezone.utc) |  | ||||||
|  |  | ||||||
|     def get_duration(self, instance: SystemTask) -> float: |  | ||||||
|         """Get the duration a task took to run""" |  | ||||||
|         return max(instance.finish_timestamp - instance.start_timestamp, 0) |  | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = SystemTask |         model = SystemTask | ||||||
|         fields = [ |         fields = [ | ||||||
| @ -87,7 +81,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet): | |||||||
|             500: OpenApiResponse(description="Failed to retry task"), |             500: OpenApiResponse(description="Failed to retry task"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, methods=["post"]) |     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||||
|     def run(self, request: Request, pk=None) -> Response: |     def run(self, request: Request, pk=None) -> Response: | ||||||
|         """Run task""" |         """Run task""" | ||||||
|         task: SystemTask = self.get_object() |         task: SystemTask = self.get_object() | ||||||
|  | |||||||
| @ -1,9 +1,12 @@ | |||||||
| """authentik events app""" | """authentik events app""" | ||||||
|  |  | ||||||
|  | from celery.schedules import crontab | ||||||
| from prometheus_client import Gauge, Histogram | from prometheus_client import Gauge, Histogram | ||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
| from authentik.lib.config import CONFIG, ENV_PREFIX | from authentik.lib.config import CONFIG, ENV_PREFIX | ||||||
|  | from authentik.lib.utils.reflection import path_to_class | ||||||
|  | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| # TODO: Deprecated metric - remove in 2024.2 or later | # TODO: Deprecated metric - remove in 2024.2 or later | ||||||
| GAUGE_TASKS = Gauge( | GAUGE_TASKS = Gauge( | ||||||
| @ -15,7 +18,7 @@ GAUGE_TASKS = Gauge( | |||||||
| SYSTEM_TASK_TIME = Histogram( | SYSTEM_TASK_TIME = Histogram( | ||||||
|     "authentik_system_tasks_time_seconds", |     "authentik_system_tasks_time_seconds", | ||||||
|     "Runtime of system tasks", |     "Runtime of system tasks", | ||||||
|     ["tenant"], |     ["tenant", "task_name", "task_uid"], | ||||||
| ) | ) | ||||||
| SYSTEM_TASK_STATUS = Gauge( | SYSTEM_TASK_STATUS = Gauge( | ||||||
|     "authentik_system_tasks_status", |     "authentik_system_tasks_status", | ||||||
| @ -32,10 +35,6 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Events" |     verbose_name = "authentik Events" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_events_signals(self): |  | ||||||
|         """Load events signals""" |  | ||||||
|         self.import_module("authentik.events.signals") |  | ||||||
|  |  | ||||||
|     def reconcile_global_check_deprecations(self): |     def reconcile_global_check_deprecations(self): | ||||||
|         """Check for config deprecations""" |         """Check for config deprecations""" | ||||||
|         from authentik.events.models import Event, EventAction |         from authentik.events.models import Event, EventAction | ||||||
| @ -57,7 +56,7 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|                 message=msg, |                 message=msg, | ||||||
|             ).save() |             ).save() | ||||||
|  |  | ||||||
|     def reconcile_prefill_tasks(self): |     def reconcile_tenant_prefill_tasks(self): | ||||||
|         """Prefill tasks""" |         """Prefill tasks""" | ||||||
|         from authentik.events.models import SystemTask |         from authentik.events.models import SystemTask | ||||||
|         from authentik.events.system_tasks import _prefill_tasks |         from authentik.events.system_tasks import _prefill_tasks | ||||||
| @ -67,3 +66,28 @@ class AuthentikEventsConfig(ManagedAppConfig): | |||||||
|                 continue |                 continue | ||||||
|             task.save() |             task.save() | ||||||
|             self.logger.debug("prefilled task", task_name=task.name) |             self.logger.debug("prefilled task", task_name=task.name) | ||||||
|  |  | ||||||
|  |     def reconcile_tenant_run_scheduled_tasks(self): | ||||||
|  |         """Run schedule tasks which are behind schedule (only applies | ||||||
|  |         to tasks of which we keep metrics)""" | ||||||
|  |         from authentik.events.models import TaskStatus | ||||||
|  |         from authentik.events.system_tasks import SystemTask as CelerySystemTask | ||||||
|  |  | ||||||
|  |         for task in CELERY_APP.conf["beat_schedule"].values(): | ||||||
|  |             schedule = task["schedule"] | ||||||
|  |             if not isinstance(schedule, crontab): | ||||||
|  |                 continue | ||||||
|  |             task_class: CelerySystemTask = path_to_class(task["task"]) | ||||||
|  |             if not isinstance(task_class, CelerySystemTask): | ||||||
|  |                 continue | ||||||
|  |             db_task = task_class.db() | ||||||
|  |             if not db_task: | ||||||
|  |                 continue | ||||||
|  |             due, _ = schedule.is_due(db_task.finish_timestamp) | ||||||
|  |             if due or db_task.status == TaskStatus.UNKNOWN: | ||||||
|  |                 self.logger.debug("Running past-due scheduled task", task=task["task"]) | ||||||
|  |                 task_class.apply_async( | ||||||
|  |                     args=task.get("args", None), | ||||||
|  |                     kwargs=task.get("kwargs", None), | ||||||
|  |                     **task.get("options", {}), | ||||||
|  |                 ) | ||||||
|  | |||||||
| @ -82,26 +82,29 @@ class AuditMiddleware: | |||||||
|  |  | ||||||
|         self.anonymous_user = get_anonymous_user() |         self.anonymous_user = get_anonymous_user() | ||||||
|  |  | ||||||
|  |     def get_user(self, request: HttpRequest) -> User: | ||||||
|  |         user = getattr(request, "user", self.anonymous_user) | ||||||
|  |         if not user.is_authenticated: | ||||||
|  |             return self.anonymous_user | ||||||
|  |         return user | ||||||
|  |  | ||||||
|     def connect(self, request: HttpRequest): |     def connect(self, request: HttpRequest): | ||||||
|         """Connect signal for automatic logging""" |         """Connect signal for automatic logging""" | ||||||
|         self._ensure_fallback_user() |         self._ensure_fallback_user() | ||||||
|         user = getattr(request, "user", self.anonymous_user) |  | ||||||
|         if not user.is_authenticated: |  | ||||||
|             user = self.anonymous_user |  | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             return |             return | ||||||
|         post_save.connect( |         post_save.connect( | ||||||
|             partial(self.post_save_handler, user=user, request=request), |             partial(self.post_save_handler, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
|         pre_delete.connect( |         pre_delete.connect( | ||||||
|             partial(self.pre_delete_handler, user=user, request=request), |             partial(self.pre_delete_handler, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
|         m2m_changed.connect( |         m2m_changed.connect( | ||||||
|             partial(self.m2m_changed_handler, user=user, request=request), |             partial(self.m2m_changed_handler, request=request), | ||||||
|             dispatch_uid=request.request_id, |             dispatch_uid=request.request_id, | ||||||
|             weak=False, |             weak=False, | ||||||
|         ) |         ) | ||||||
| @ -147,7 +150,6 @@ class AuditMiddleware: | |||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def post_save_handler( |     def post_save_handler( | ||||||
|         self, |         self, | ||||||
|         user: User, |  | ||||||
|         request: HttpRequest, |         request: HttpRequest, | ||||||
|         sender, |         sender, | ||||||
|         instance: Model, |         instance: Model, | ||||||
| @ -158,16 +160,18 @@ class AuditMiddleware: | |||||||
|         """Signal handler for all object's post_save""" |         """Signal handler for all object's post_save""" | ||||||
|         if not should_log_model(instance): |         if not should_log_model(instance): | ||||||
|             return |             return | ||||||
|  |         user = self.get_user(request) | ||||||
|  |  | ||||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED |         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||||
|         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) |         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) | ||||||
|         thread.kwargs.update(thread_kwargs or {}) |         thread.kwargs.update(thread_kwargs or {}) | ||||||
|         thread.run() |         thread.run() | ||||||
|  |  | ||||||
|     def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): |     def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||||
|         """Signal handler for all object's pre_delete""" |         """Signal handler for all object's pre_delete""" | ||||||
|         if not should_log_model(instance):  # pragma: no cover |         if not should_log_model(instance):  # pragma: no cover | ||||||
|             return |             return | ||||||
|  |         user = self.get_user(request) | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|             EventAction.MODEL_DELETED, |             EventAction.MODEL_DELETED, | ||||||
| @ -176,14 +180,13 @@ class AuditMiddleware: | |||||||
|             model=model_to_dict(instance), |             model=model_to_dict(instance), | ||||||
|         ).run() |         ).run() | ||||||
|  |  | ||||||
|     def m2m_changed_handler( |     def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_): | ||||||
|         self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ |  | ||||||
|     ): |  | ||||||
|         """Signal handler for all object's m2m_changed""" |         """Signal handler for all object's m2m_changed""" | ||||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: |         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||||
|             return |             return | ||||||
|         if not should_log_m2m(instance): |         if not should_log_m2m(instance): | ||||||
|             return |             return | ||||||
|  |         user = self.get_user(request) | ||||||
|  |  | ||||||
|         EventNewThread( |         EventNewThread( | ||||||
|             EventAction.MODEL_UPDATED, |             EventAction.MODEL_UPDATED, | ||||||
|  | |||||||
| @ -0,0 +1,68 @@ | |||||||
|  | # Generated by Django 5.0.1 on 2024-02-07 15:42 | ||||||
|  |  | ||||||
|  | import uuid | ||||||
|  |  | ||||||
|  | import django.utils.timezone | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  | import authentik.core.models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     replaces = [ | ||||||
|  |         ("authentik_events", "0004_systemtask"), | ||||||
|  |         ("authentik_events", "0005_remove_systemtask_finish_timestamp_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_events", "0003_rename_tenant_event_brand"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.CreateModel( | ||||||
|  |             name="SystemTask", | ||||||
|  |             fields=[ | ||||||
|  |                 ( | ||||||
|  |                     "expires", | ||||||
|  |                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||||
|  |                 ), | ||||||
|  |                 ("expiring", models.BooleanField(default=True)), | ||||||
|  |                 ( | ||||||
|  |                     "uuid", | ||||||
|  |                     models.UUIDField( | ||||||
|  |                         default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("name", models.TextField()), | ||||||
|  |                 ("uid", models.TextField(null=True)), | ||||||
|  |                 ( | ||||||
|  |                     "status", | ||||||
|  |                     models.TextField( | ||||||
|  |                         choices=[ | ||||||
|  |                             ("unknown", "Unknown"), | ||||||
|  |                             ("successful", "Successful"), | ||||||
|  |                             ("warning", "Warning"), | ||||||
|  |                             ("error", "Error"), | ||||||
|  |                         ] | ||||||
|  |                     ), | ||||||
|  |                 ), | ||||||
|  |                 ("description", models.TextField(null=True)), | ||||||
|  |                 ("messages", models.JSONField()), | ||||||
|  |                 ("task_call_module", models.TextField()), | ||||||
|  |                 ("task_call_func", models.TextField()), | ||||||
|  |                 ("task_call_args", models.JSONField(default=list)), | ||||||
|  |                 ("task_call_kwargs", models.JSONField(default=dict)), | ||||||
|  |                 ("duration", models.FloatField(default=0)), | ||||||
|  |                 ("finish_timestamp", models.DateTimeField(default=django.utils.timezone.now)), | ||||||
|  |                 ("start_timestamp", models.DateTimeField(default=django.utils.timezone.now)), | ||||||
|  |             ], | ||||||
|  |             options={ | ||||||
|  |                 "verbose_name": "System Task", | ||||||
|  |                 "verbose_name_plural": "System Tasks", | ||||||
|  |                 "permissions": [("run_task", "Run task")], | ||||||
|  |                 "default_permissions": ["view"], | ||||||
|  |                 "unique_together": {("name", "uid")}, | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -0,0 +1,37 @@ | |||||||
|  | # Generated by Django 5.0.1 on 2024-02-06 18:02 | ||||||
|  |  | ||||||
|  | import django.utils.timezone | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_events", "0004_systemtask"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="systemtask", | ||||||
|  |             name="finish_timestamp", | ||||||
|  |         ), | ||||||
|  |         migrations.RemoveField( | ||||||
|  |             model_name="systemtask", | ||||||
|  |             name="start_timestamp", | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="systemtask", | ||||||
|  |             name="duration", | ||||||
|  |             field=models.FloatField(default=0), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="systemtask", | ||||||
|  |             name="finish_timestamp", | ||||||
|  |             field=models.DateTimeField(default=django.utils.timezone.now), | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="systemtask", | ||||||
|  |             name="start_timestamp", | ||||||
|  |             field=models.DateTimeField(default=django.utils.timezone.now), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -451,6 +451,13 @@ class NotificationTransport(SerializerModel): | |||||||
|  |  | ||||||
|     def send_email(self, notification: "Notification") -> list[str]: |     def send_email(self, notification: "Notification") -> list[str]: | ||||||
|         """Send notification via global email configuration""" |         """Send notification via global email configuration""" | ||||||
|  |         if notification.user.email.strip() == "": | ||||||
|  |             LOGGER.info( | ||||||
|  |                 "Discarding notification as user has no email address", | ||||||
|  |                 user=notification.user, | ||||||
|  |                 notification=notification, | ||||||
|  |             ) | ||||||
|  |             return None | ||||||
|         subject_prefix = "authentik Notification: " |         subject_prefix = "authentik Notification: " | ||||||
|         context = { |         context = { | ||||||
|             "key_value": { |             "key_value": { | ||||||
| @ -480,7 +487,7 @@ class NotificationTransport(SerializerModel): | |||||||
|             } |             } | ||||||
|         mail = TemplateEmailMessage( |         mail = TemplateEmailMessage( | ||||||
|             subject=subject_prefix + context["title"], |             subject=subject_prefix + context["title"], | ||||||
|             to=[f"{notification.user.name} <{notification.user.email}>"], |             to=[(notification.user.name, notification.user.email)], | ||||||
|             language=notification.user.locale(), |             language=notification.user.locale(), | ||||||
|             template_name="email/event_notification.html", |             template_name="email/event_notification.html", | ||||||
|             template_context=context, |             template_context=context, | ||||||
| @ -620,8 +627,9 @@ class SystemTask(SerializerModel, ExpiringModel): | |||||||
|     name = models.TextField() |     name = models.TextField() | ||||||
|     uid = models.TextField(null=True) |     uid = models.TextField(null=True) | ||||||
|  |  | ||||||
|     start_timestamp = models.FloatField() |     start_timestamp = models.DateTimeField(default=now) | ||||||
|     finish_timestamp = models.FloatField() |     finish_timestamp = models.DateTimeField(default=now) | ||||||
|  |     duration = models.FloatField(default=0) | ||||||
|  |  | ||||||
|     status = models.TextField(choices=TaskStatus.choices) |     status = models.TextField(choices=TaskStatus.choices) | ||||||
|  |  | ||||||
| @ -641,17 +649,18 @@ class SystemTask(SerializerModel, ExpiringModel): | |||||||
|  |  | ||||||
|     def update_metrics(self): |     def update_metrics(self): | ||||||
|         """Update prometheus metrics""" |         """Update prometheus metrics""" | ||||||
|         duration = max(self.finish_timestamp - self.start_timestamp, 0) |  | ||||||
|         # TODO: Deprecated metric - remove in 2024.2 or later |         # TODO: Deprecated metric - remove in 2024.2 or later | ||||||
|         GAUGE_TASKS.labels( |         GAUGE_TASKS.labels( | ||||||
|             tenant=connection.schema_name, |             tenant=connection.schema_name, | ||||||
|             task_name=self.name, |             task_name=self.name, | ||||||
|             task_uid=self.uid or "", |             task_uid=self.uid or "", | ||||||
|             status=self.status.lower(), |             status=self.status.lower(), | ||||||
|         ).set(duration) |         ).set(self.duration) | ||||||
|         SYSTEM_TASK_TIME.labels( |         SYSTEM_TASK_TIME.labels( | ||||||
|             tenant=connection.schema_name, |             tenant=connection.schema_name, | ||||||
|         ).observe(duration) |             task_name=self.name, | ||||||
|  |             task_uid=self.uid or "", | ||||||
|  |         ).observe(self.duration) | ||||||
|         SYSTEM_TASK_STATUS.labels( |         SYSTEM_TASK_STATUS.labels( | ||||||
|             tenant=connection.schema_name, |             tenant=connection.schema_name, | ||||||
|             task_name=self.name, |             task_name=self.name, | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| """Monitored tasks""" | """Monitored tasks""" | ||||||
|  |  | ||||||
| from datetime import timedelta | from datetime import datetime, timedelta | ||||||
| from timeit import default_timer | from time import perf_counter | ||||||
| from typing import Any, Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| @ -24,14 +24,17 @@ class SystemTask(TenantTask): | |||||||
|     # For tasks that should only be listed if they failed, set this to False |     # For tasks that should only be listed if they failed, set this to False | ||||||
|     save_on_success: bool |     save_on_success: bool | ||||||
|  |  | ||||||
|     _status: Optional[TaskStatus] |     _status: TaskStatus | ||||||
|     _messages: list[str] |     _messages: list[str] | ||||||
|  |  | ||||||
|     _uid: Optional[str] |     _uid: Optional[str] | ||||||
|     _start: Optional[float] = None |     # Precise start time from perf_counter | ||||||
|  |     _start_precise: Optional[float] = None | ||||||
|  |     _start: Optional[datetime] = None | ||||||
|  |  | ||||||
|     def __init__(self, *args, **kwargs) -> None: |     def __init__(self, *args, **kwargs) -> None: | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|  |         self._status = TaskStatus.SUCCESSFUL | ||||||
|         self.save_on_success = True |         self.save_on_success = True | ||||||
|         self._uid = None |         self._uid = None | ||||||
|         self._status = None |         self._status = None | ||||||
| @ -53,9 +56,17 @@ class SystemTask(TenantTask): | |||||||
|         self._messages = [exception_to_string(exception)] |         self._messages = [exception_to_string(exception)] | ||||||
|  |  | ||||||
|     def before_start(self, task_id, args, kwargs): |     def before_start(self, task_id, args, kwargs): | ||||||
|         self._start = default_timer() |         self._start_precise = perf_counter() | ||||||
|  |         self._start = now() | ||||||
|         return super().before_start(task_id, args, kwargs) |         return super().before_start(task_id, args, kwargs) | ||||||
|  |  | ||||||
|  |     def db(self) -> Optional[DBSystemTask]: | ||||||
|  |         """Get DB object for latest task""" | ||||||
|  |         return DBSystemTask.objects.filter( | ||||||
|  |             name=self.__name__, | ||||||
|  |             uid=self._uid, | ||||||
|  |         ).first() | ||||||
|  |  | ||||||
|     # pylint: disable=too-many-arguments |     # pylint: disable=too-many-arguments | ||||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): |     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) |         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||||
| @ -72,12 +83,13 @@ class SystemTask(TenantTask): | |||||||
|             uid=self._uid, |             uid=self._uid, | ||||||
|             defaults={ |             defaults={ | ||||||
|                 "description": self.__doc__, |                 "description": self.__doc__, | ||||||
|                 "start_timestamp": self._start or default_timer(), |                 "start_timestamp": self._start or now(), | ||||||
|                 "finish_timestamp": default_timer(), |                 "finish_timestamp": now(), | ||||||
|  |                 "duration": max(perf_counter() - self._start_precise, 0), | ||||||
|                 "task_call_module": self.__module__, |                 "task_call_module": self.__module__, | ||||||
|                 "task_call_func": self.__name__, |                 "task_call_func": self.__name__, | ||||||
|                 "task_call_args": args, |                 "task_call_args": sanitize_item(args), | ||||||
|                 "task_call_kwargs": kwargs, |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|                 "status": self._status, |                 "status": self._status, | ||||||
|                 "messages": sanitize_item(self._messages), |                 "messages": sanitize_item(self._messages), | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), |                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||||
| @ -96,12 +108,13 @@ class SystemTask(TenantTask): | |||||||
|             uid=self._uid, |             uid=self._uid, | ||||||
|             defaults={ |             defaults={ | ||||||
|                 "description": self.__doc__, |                 "description": self.__doc__, | ||||||
|                 "start_timestamp": self._start or default_timer(), |                 "start_timestamp": self._start or now(), | ||||||
|                 "finish_timestamp": default_timer(), |                 "finish_timestamp": now(), | ||||||
|  |                 "duration": max(perf_counter() - self._start_precise, 0), | ||||||
|                 "task_call_module": self.__module__, |                 "task_call_module": self.__module__, | ||||||
|                 "task_call_func": self.__name__, |                 "task_call_func": self.__name__, | ||||||
|                 "task_call_args": args, |                 "task_call_args": sanitize_item(args), | ||||||
|                 "task_call_kwargs": kwargs, |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|                 "status": self._status, |                 "status": self._status, | ||||||
|                 "messages": sanitize_item(self._messages), |                 "messages": sanitize_item(self._messages), | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), |                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||||
| @ -123,11 +136,14 @@ def prefill_task(func): | |||||||
|         DBSystemTask( |         DBSystemTask( | ||||||
|             name=func.__name__, |             name=func.__name__, | ||||||
|             description=func.__doc__, |             description=func.__doc__, | ||||||
|  |             start_timestamp=now(), | ||||||
|  |             finish_timestamp=now(), | ||||||
|             status=TaskStatus.UNKNOWN, |             status=TaskStatus.UNKNOWN, | ||||||
|             messages=sanitize_item([_("Task has not been run yet.")]), |             messages=sanitize_item([_("Task has not been run yet.")]), | ||||||
|             task_call_module=func.__module__, |             task_call_module=func.__module__, | ||||||
|             task_call_func=func.__name__, |             task_call_func=func.__name__, | ||||||
|             expiring=False, |             expiring=False, | ||||||
|  |             duration=0, | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|     return func |     return func | ||||||
|  | |||||||
| @ -3,9 +3,10 @@ | |||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application, Token, TokenIntents | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEventsMiddleware(APITestCase): | class TestEventsMiddleware(APITestCase): | ||||||
| @ -47,3 +48,30 @@ class TestEventsMiddleware(APITestCase): | |||||||
|                 context__model__name="test-delete", |                 context__model__name="test-delete", | ||||||
|             ).exists() |             ).exists() | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_create_with_api(self): | ||||||
|  |         """Test model creation event (with API token auth)""" | ||||||
|  |         self.client.logout() | ||||||
|  |         token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) | ||||||
|  |         uid = generate_id() | ||||||
|  |         self.client.post( | ||||||
|  |             reverse("authentik_api:application-list"), | ||||||
|  |             data={"name": uid, "slug": uid}, | ||||||
|  |             HTTP_AUTHORIZATION=f"Bearer {token.key}", | ||||||
|  |         ) | ||||||
|  |         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||||
|  |         event = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_CREATED, | ||||||
|  |             context__model__model_name="application", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__name=uid, | ||||||
|  |         ).first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertEqual( | ||||||
|  |             event.user, | ||||||
|  |             { | ||||||
|  |                 "pk": self.user.pk, | ||||||
|  |                 "email": self.user.email, | ||||||
|  |                 "username": self.user.username, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  | |||||||
| @ -15,7 +15,6 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField | |||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.blueprints.v1.exporter import FlowExporter | from authentik.blueprints.v1.exporter import FlowExporter | ||||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| @ -33,6 +32,7 @@ from authentik.lib.utils.file import ( | |||||||
|     set_file_url, |     set_file_url, | ||||||
| ) | ) | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | |||||||
| @ -31,10 +31,6 @@ class AuthentikFlowsConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Flows" |     verbose_name = "authentik Flows" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_flows_signals(self): |  | ||||||
|         """Load flows signals""" |  | ||||||
|         self.import_module("authentik.flows.signals") |  | ||||||
|  |  | ||||||
|     def reconcile_global_load_stages(self): |     def reconcile_global_load_stages(self): | ||||||
|         """Ensure all stages are loaded""" |         """Ensure all stages are loaded""" | ||||||
|         from authentik.flows.models import Stage |         from authentik.flows.models import Stage | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| """flow views tests""" | """flow views tests""" | ||||||
|  |  | ||||||
| from unittest.mock import MagicMock, PropertyMock, patch | from unittest.mock import MagicMock, PropertyMock, patch | ||||||
|  | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.test.client import RequestFactory | from django.test.client import RequestFactory | ||||||
| @ -18,7 +19,12 @@ from authentik.flows.models import ( | |||||||
| from authentik.flows.planner import FlowPlan, FlowPlanner | from authentik.flows.planner import FlowPlan, FlowPlanner | ||||||
| from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | ||||||
| from authentik.flows.tests import FlowTestCase | from authentik.flows.tests import FlowTestCase | ||||||
| from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | from authentik.flows.views.executor import ( | ||||||
|  |     NEXT_ARG_NAME, | ||||||
|  |     QS_QUERY, | ||||||
|  |     SESSION_KEY_PLAN, | ||||||
|  |     FlowExecutorView, | ||||||
|  | ) | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| @ -121,16 +127,73 @@ class TestFlowExecutor(FlowTestCase): | |||||||
|         TO_STAGE_RESPONSE_MOCK, |         TO_STAGE_RESPONSE_MOCK, | ||||||
|     ) |     ) | ||||||
|     def test_invalid_flow_redirect(self): |     def test_invalid_flow_redirect(self): | ||||||
|         """Tests that an invalid flow still redirects""" |         """Test invalid flow with valid redirect destination""" | ||||||
|         flow = create_test_flow( |         flow = create_test_flow( | ||||||
|             FlowDesignation.AUTHENTICATION, |             FlowDesignation.AUTHENTICATION, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         dest = "/unique-string" |         dest = "/unique-string" | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|         response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") |         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||||
|         self.assertEqual(response.status_code, 302) |         self.assertEqual(response.status_code, 302) | ||||||
|         self.assertEqual(response.url, reverse("authentik_core:root-redirect")) |         self.assertEqual(response.url, "/unique-string") | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.flows.views.executor.to_stage_response", | ||||||
|  |         TO_STAGE_RESPONSE_MOCK, | ||||||
|  |     ) | ||||||
|  |     def test_invalid_flow_invalid_redirect(self): | ||||||
|  |         """Test invalid flow redirect with an invalid URL""" | ||||||
|  |         flow = create_test_flow( | ||||||
|  |             FlowDesignation.AUTHENTICATION, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         dest = "http://something.example.com/unique-string" | ||||||
|  |         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|  |  | ||||||
|  |         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         self.assertStageResponse( | ||||||
|  |             response, | ||||||
|  |             flow, | ||||||
|  |             component="ak-stage-access-denied", | ||||||
|  |             error_message="Invalid next URL", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.flows.views.executor.to_stage_response", | ||||||
|  |         TO_STAGE_RESPONSE_MOCK, | ||||||
|  |     ) | ||||||
|  |     def test_valid_flow_redirect(self): | ||||||
|  |         """Test valid flow with valid redirect destination""" | ||||||
|  |         flow = create_test_flow() | ||||||
|  |  | ||||||
|  |         dest = "/unique-string" | ||||||
|  |         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|  |  | ||||||
|  |         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||||
|  |         self.assertEqual(response.status_code, 302) | ||||||
|  |         self.assertEqual(response.url, "/unique-string") | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.flows.views.executor.to_stage_response", | ||||||
|  |         TO_STAGE_RESPONSE_MOCK, | ||||||
|  |     ) | ||||||
|  |     def test_valid_flow_invalid_redirect(self): | ||||||
|  |         """Test valid flow redirect with an invalid URL""" | ||||||
|  |         flow = create_test_flow() | ||||||
|  |  | ||||||
|  |         dest = "http://something.example.com/unique-string" | ||||||
|  |         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||||
|  |  | ||||||
|  |         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         self.assertStageResponse( | ||||||
|  |             response, | ||||||
|  |             flow, | ||||||
|  |             component="ak-stage-access-denied", | ||||||
|  |             error_message="Invalid next URL", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @patch( |     @patch( | ||||||
|         "authentik.flows.views.executor.to_stage_response", |         "authentik.flows.views.executor.to_stage_response", | ||||||
|  | |||||||
| @ -12,6 +12,7 @@ from django.shortcuts import get_object_or_404, redirect | |||||||
| from django.template.response import TemplateResponse | from django.template.response import TemplateResponse | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.decorators import method_decorator | from django.utils.decorators import method_decorator | ||||||
|  | from django.utils.translation import gettext as _ | ||||||
| from django.views.decorators.clickjacking import xframe_options_sameorigin | from django.views.decorators.clickjacking import xframe_options_sameorigin | ||||||
| from django.views.generic import View | from django.views.generic import View | ||||||
| from drf_spectacular.types import OpenApiTypes | from drf_spectacular.types import OpenApiTypes | ||||||
| @ -178,6 +179,8 @@ class FlowExecutorView(APIView): | |||||||
|                     self.cancel() |                     self.cancel() | ||||||
|                 self._logger.debug("f(exec): Continuing existing plan") |                 self._logger.debug("f(exec): Continuing existing plan") | ||||||
|  |  | ||||||
|  |             # Initial flow request, check if we have an upstream query string passed in | ||||||
|  |             request.session[SESSION_KEY_GET] = get_params | ||||||
|             # Don't check session again as we've either already loaded the plan or we need to plan |             # Don't check session again as we've either already loaded the plan or we need to plan | ||||||
|             if not self.plan: |             if not self.plan: | ||||||
|                 request.session[SESSION_KEY_HISTORY] = [] |                 request.session[SESSION_KEY_HISTORY] = [] | ||||||
| @ -192,8 +195,6 @@ class FlowExecutorView(APIView): | |||||||
|                     # To match behaviour with loading an empty flow plan from cache, |                     # To match behaviour with loading an empty flow plan from cache, | ||||||
|                     # we don't show an error message here, but rather call _flow_done() |                     # we don't show an error message here, but rather call _flow_done() | ||||||
|                     return self._flow_done() |                     return self._flow_done() | ||||||
|             # Initial flow request, check if we have an upstream query string passed in |  | ||||||
|             request.session[SESSION_KEY_GET] = get_params |  | ||||||
|             # We don't save the Plan after getting the next stage |             # We don't save the Plan after getting the next stage | ||||||
|             # as it hasn't been successfully passed yet |             # as it hasn't been successfully passed yet | ||||||
|             try: |             try: | ||||||
| @ -392,7 +393,11 @@ class FlowExecutorView(APIView): | |||||||
|             NEXT_ARG_NAME, "authentik_core:root-redirect" |             NEXT_ARG_NAME, "authentik_core:root-redirect" | ||||||
|         ) |         ) | ||||||
|         self.cancel() |         self.cancel() | ||||||
|         return to_stage_response(self.request, redirect_with_qs(next_param)) |         if next_param and not is_url_absolute(next_param): | ||||||
|  |             return to_stage_response(self.request, redirect_with_qs(next_param)) | ||||||
|  |         return to_stage_response( | ||||||
|  |             self.request, self.stage_invalid(error_message=_("Invalid next URL")) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def stage_ok(self) -> HttpResponse: |     def stage_ok(self) -> HttpResponse: | ||||||
|         """Callback called by stages upon successful completion. |         """Callback called by stages upon successful completion. | ||||||
|  | |||||||
| @ -30,10 +30,6 @@ class AuthentikOutpostConfig(ManagedAppConfig): | |||||||
|     verbose_name = "authentik Outpost" |     verbose_name = "authentik Outpost" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_outposts_signals(self): |  | ||||||
|         """Load outposts signals""" |  | ||||||
|         self.import_module("authentik.outposts.signals") |  | ||||||
|  |  | ||||||
|     def reconcile_tenant_embedded_outpost(self): |     def reconcile_tenant_embedded_outpost(self): | ||||||
|         """Ensure embedded outpost""" |         """Ensure embedded outpost""" | ||||||
|         from authentik.outposts.models import ( |         from authentik.outposts.models import ( | ||||||
|  | |||||||
| @ -13,7 +13,6 @@ from rest_framework.viewsets import GenericViewSet | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| from structlog.testing import capture_logs | from structlog.testing import capture_logs | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.applications import user_app_cache_key | from authentik.core.api.applications import user_app_cache_key | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | ||||||
| @ -23,6 +22,7 @@ from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSe | |||||||
| from authentik.policies.models import Policy, PolicyBinding | from authentik.policies.models import Policy, PolicyBinding | ||||||
| from authentik.policies.process import PolicyProcess | from authentik.policies.process import PolicyProcess | ||||||
| from authentik.policies.types import CACHE_PREFIX, PolicyRequest | from authentik.policies.types import CACHE_PREFIX, PolicyRequest | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  | |||||||
| @ -35,7 +35,3 @@ class AuthentikPoliciesConfig(ManagedAppConfig): | |||||||
|     label = "authentik_policies" |     label = "authentik_policies" | ||||||
|     verbose_name = "authentik Policies" |     verbose_name = "authentik Policies" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_policies_signals(self): |  | ||||||
|         """Load policies signals""" |  | ||||||
|         self.import_module("authentik.policies.signals") |  | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from multiprocessing import Pipe, current_process | from multiprocessing import Pipe, current_process | ||||||
| from multiprocessing.connection import Connection | from multiprocessing.connection import Connection | ||||||
| from timeit import default_timer | from time import perf_counter | ||||||
| from typing import Iterator, Optional | from typing import Iterator, Optional | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| @ -84,10 +84,10 @@ class PolicyEngine: | |||||||
|     def _check_cache(self, binding: PolicyBinding): |     def _check_cache(self, binding: PolicyBinding): | ||||||
|         if not self.use_cache: |         if not self.use_cache: | ||||||
|             return False |             return False | ||||||
|         before = default_timer() |         before = perf_counter() | ||||||
|         key = cache_key(binding, self.request) |         key = cache_key(binding, self.request) | ||||||
|         cached_policy = cache.get(key, None) |         cached_policy = cache.get(key, None) | ||||||
|         duration = max(default_timer() - before, 0) |         duration = max(perf_counter() - before, 0) | ||||||
|         if not cached_policy: |         if not cached_policy: | ||||||
|             return False |             return False | ||||||
|         self.logger.debug( |         self.logger.debug( | ||||||
|  | |||||||
| @ -2,6 +2,8 @@ | |||||||
|  |  | ||||||
| from authentik.blueprints.apps import ManagedAppConfig | from authentik.blueprints.apps import ManagedAppConfig | ||||||
|  |  | ||||||
|  | CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikPolicyReputationConfig(ManagedAppConfig): | class AuthentikPolicyReputationConfig(ManagedAppConfig): | ||||||
|     """Authentik reputation app config""" |     """Authentik reputation app config""" | ||||||
| @ -10,11 +12,3 @@ class AuthentikPolicyReputationConfig(ManagedAppConfig): | |||||||
|     label = "authentik_policies_reputation" |     label = "authentik_policies_reputation" | ||||||
|     verbose_name = "authentik Policies.Reputation" |     verbose_name = "authentik Policies.Reputation" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_policies_reputation_signals(self): |  | ||||||
|         """Load policies.reputation signals""" |  | ||||||
|         self.import_module("authentik.policies.reputation.signals") |  | ||||||
|  |  | ||||||
|     def reconcile_global_load_policies_reputation_tasks(self): |  | ||||||
|         """Load policies.reputation tasks""" |  | ||||||
|         self.import_module("authentik.policies.reputation.tasks") |  | ||||||
|  | |||||||
| @ -19,7 +19,6 @@ from authentik.policies.types import PolicyRequest, PolicyResult | |||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def reputation_expiry(): | def reputation_expiry(): | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from structlog.stdlib import get_logger | |||||||
|  |  | ||||||
| from authentik.core.signals import login_failed | from authentik.core.signals import login_failed | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.policies.reputation.models import CACHE_KEY_PREFIX | from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||||
| from authentik.policies.reputation.tasks import save_reputation | from authentik.policies.reputation.tasks import save_reputation | ||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.identification.signals import identification_failed | from authentik.stages.identification.signals import identification_failed | ||||||
|  | |||||||
| @ -7,8 +7,8 @@ from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR | |||||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR | from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR | ||||||
| from authentik.events.models import TaskStatus | from authentik.events.models import TaskStatus | ||||||
| from authentik.events.system_tasks import SystemTask, prefill_task | from authentik.events.system_tasks import SystemTask, prefill_task | ||||||
|  | from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||||
| from authentik.policies.reputation.models import Reputation | from authentik.policies.reputation.models import Reputation | ||||||
| from authentik.policies.reputation.signals import CACHE_KEY_PREFIX |  | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  | |||||||
| @ -6,7 +6,8 @@ from django.test import RequestFactory, TestCase | |||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.reputation.api import ReputationPolicySerializer | from authentik.policies.reputation.api import ReputationPolicySerializer | ||||||
| from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy | from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||||
|  | from authentik.policies.reputation.models import Reputation, ReputationPolicy | ||||||
| from authentik.policies.reputation.tasks import save_reputation | from authentik.policies.reputation.tasks import save_reputation | ||||||
| from authentik.policies.types import PolicyRequest | from authentik.policies.types import PolicyRequest | ||||||
| from authentik.stages.password import BACKEND_INBUILT | from authentik.stages.password import BACKEND_INBUILT | ||||||
|  | |||||||
| @ -15,13 +15,13 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||||
| from authentik.core.models import Provider | from authentik.core.models import Provider | ||||||
| from authentik.providers.oauth2.id_token import IDToken | from authentik.providers.oauth2.id_token import IDToken | ||||||
| from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping | from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class OAuth2ProviderSerializer(ProviderSerializer): | class OAuth2ProviderSerializer(ProviderSerializer): | ||||||
|  | |||||||
| @ -36,8 +36,21 @@ class TestAuthorize(OAuthTestCase): | |||||||
|  |  | ||||||
|     def test_invalid_grant_type(self): |     def test_invalid_grant_type(self): | ||||||
|         """Test with invalid grant type""" |         """Test with invalid grant type""" | ||||||
|  |         OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             client_id="test", | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris="http://local.invalid/Foo", | ||||||
|  |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError): | ||||||
|             request = self.factory.get("/", data={"response_type": "invalid"}) |             request = self.factory.get( | ||||||
|  |                 "/", | ||||||
|  |                 data={ | ||||||
|  |                     "response_type": "invalid", | ||||||
|  |                     "client_id": "test", | ||||||
|  |                     "redirect_uri": "http://local.invalid/Foo", | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |  | ||||||
|     def test_invalid_client_id(self): |     def test_invalid_client_id(self): | ||||||
| @ -344,7 +357,12 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 ] |                 ] | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         Application.objects.create(name="app", slug="app", provider=provider) |         provider.property_mappings.add( | ||||||
|  |             ScopeMapping.objects.create( | ||||||
|  |                 name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}""" | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||||
|         state = generate_id() |         state = generate_id() | ||||||
|         user = create_test_admin_user() |         user = create_test_admin_user() | ||||||
|         self.client.force_login(user) |         self.client.force_login(user) | ||||||
| @ -365,7 +383,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                     "response_type": "id_token", |                     "response_type": "id_token", | ||||||
|                     "client_id": "test", |                     "client_id": "test", | ||||||
|                     "state": state, |                     "state": state, | ||||||
|                     "scope": "openid", |                     "scope": "openid test", | ||||||
|                     "redirect_uri": "http://localhost", |                     "redirect_uri": "http://localhost", | ||||||
|                     "nonce": generate_id(), |                     "nonce": generate_id(), | ||||||
|                 }, |                 }, | ||||||
| @ -390,6 +408,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             ) |             ) | ||||||
|             jwt = self.validate_jwt(token, provider) |             jwt = self.validate_jwt(token, provider) | ||||||
|             self.assertEqual(jwt["amr"], ["pwd"]) |             self.assertEqual(jwt["amr"], ["pwd"]) | ||||||
|  |             self.assertEqual(jwt["sub"], "foo") | ||||||
|             self.assertAlmostEqual( |             self.assertAlmostEqual( | ||||||
|                 jwt["exp"] - now().timestamp(), |                 jwt["exp"] - now().timestamp(), | ||||||
|                 expires, |                 expires, | ||||||
|  | |||||||
| @ -4,9 +4,10 @@ from urllib.parse import urlencode | |||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application, Group | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||||
| @ -77,3 +78,23 @@ class TesOAuth2DeviceInit(OAuthTestCase): | |||||||
|             + "?" |             + "?" | ||||||
|             + urlencode({QS_KEY_CODE: token.user_code}), |             + urlencode({QS_KEY_CODE: token.user_code}), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_device_init_denied(self): | ||||||
|  |         """Test device init""" | ||||||
|  |         group = Group.objects.create(name="foo") | ||||||
|  |         PolicyBinding.objects.create( | ||||||
|  |             group=group, | ||||||
|  |             target=self.application, | ||||||
|  |             order=0, | ||||||
|  |         ) | ||||||
|  |         token = DeviceToken.objects.create( | ||||||
|  |             user_code="foo", | ||||||
|  |             provider=self.provider, | ||||||
|  |         ) | ||||||
|  |         res = self.client.get( | ||||||
|  |             reverse("authentik_providers_oauth2_root:device-login") | ||||||
|  |             + "?" | ||||||
|  |             + urlencode({QS_KEY_CODE: token.user_code}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         self.assertIn(b"Permission denied", res.content) | ||||||
|  | |||||||
| @ -121,44 +121,18 @@ class OAuthAuthorizationParams: | |||||||
|         redirect_uri = query_dict.get("redirect_uri", "") |         redirect_uri = query_dict.get("redirect_uri", "") | ||||||
|  |  | ||||||
|         response_type = query_dict.get("response_type", "") |         response_type = query_dict.get("response_type", "") | ||||||
|         grant_type = None |  | ||||||
|         # Determine which flow to use. |  | ||||||
|         if response_type in [ResponseTypes.CODE]: |  | ||||||
|             grant_type = GrantTypes.AUTHORIZATION_CODE |  | ||||||
|         elif response_type in [ |  | ||||||
|             ResponseTypes.ID_TOKEN, |  | ||||||
|             ResponseTypes.ID_TOKEN_TOKEN, |  | ||||||
|         ]: |  | ||||||
|             grant_type = GrantTypes.IMPLICIT |  | ||||||
|         elif response_type in [ |  | ||||||
|             ResponseTypes.CODE_TOKEN, |  | ||||||
|             ResponseTypes.CODE_ID_TOKEN, |  | ||||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, |  | ||||||
|         ]: |  | ||||||
|             grant_type = GrantTypes.HYBRID |  | ||||||
|  |  | ||||||
|         # Grant type validation. |  | ||||||
|         if not grant_type: |  | ||||||
|             LOGGER.warning("Invalid response type", type=response_type) |  | ||||||
|             raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state) |  | ||||||
|  |  | ||||||
|         # Validate and check the response_mode against the predefined dict |         # Validate and check the response_mode against the predefined dict | ||||||
|         # Set to Query or Fragment if not defined in request |         # Set to Query or Fragment if not defined in request | ||||||
|         response_mode = query_dict.get("response_mode", False) |         response_mode = query_dict.get("response_mode", False) | ||||||
|  |  | ||||||
|         if response_mode not in ResponseMode.values: |  | ||||||
|             response_mode = ResponseMode.QUERY |  | ||||||
|  |  | ||||||
|             if grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]: |  | ||||||
|                 response_mode = ResponseMode.FRAGMENT |  | ||||||
|  |  | ||||||
|         max_age = query_dict.get("max_age") |         max_age = query_dict.get("max_age") | ||||||
|         return OAuthAuthorizationParams( |         return OAuthAuthorizationParams( | ||||||
|             client_id=query_dict.get("client_id", ""), |             client_id=query_dict.get("client_id", ""), | ||||||
|             redirect_uri=redirect_uri, |             redirect_uri=redirect_uri, | ||||||
|             response_type=response_type, |             response_type=response_type, | ||||||
|             response_mode=response_mode, |             response_mode=response_mode, | ||||||
|             grant_type=grant_type, |             grant_type="", | ||||||
|             scope=set(query_dict.get("scope", "").split()), |             scope=set(query_dict.get("scope", "").split()), | ||||||
|             state=state, |             state=state, | ||||||
|             nonce=query_dict.get("nonce"), |             nonce=query_dict.get("nonce"), | ||||||
| @ -178,6 +152,7 @@ class OAuthAuthorizationParams: | |||||||
|             LOGGER.warning("Invalid client identifier", client_id=self.client_id) |             LOGGER.warning("Invalid client identifier", client_id=self.client_id) | ||||||
|             raise ClientIdError(client_id=self.client_id) |             raise ClientIdError(client_id=self.client_id) | ||||||
|         self.check_redirect_uri() |         self.check_redirect_uri() | ||||||
|  |         self.check_grant() | ||||||
|         self.check_scope(github_compat) |         self.check_scope(github_compat) | ||||||
|         self.check_nonce() |         self.check_nonce() | ||||||
|         self.check_code_challenge() |         self.check_code_challenge() | ||||||
| @ -186,6 +161,34 @@ class OAuthAuthorizationParams: | |||||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state |                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |     def check_grant(self): | ||||||
|  |         """Check grant""" | ||||||
|  |         # Determine which flow to use. | ||||||
|  |         if self.response_type in [ResponseTypes.CODE]: | ||||||
|  |             self.grant_type = GrantTypes.AUTHORIZATION_CODE | ||||||
|  |         elif self.response_type in [ | ||||||
|  |             ResponseTypes.ID_TOKEN, | ||||||
|  |             ResponseTypes.ID_TOKEN_TOKEN, | ||||||
|  |         ]: | ||||||
|  |             self.grant_type = GrantTypes.IMPLICIT | ||||||
|  |         elif self.response_type in [ | ||||||
|  |             ResponseTypes.CODE_TOKEN, | ||||||
|  |             ResponseTypes.CODE_ID_TOKEN, | ||||||
|  |             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||||
|  |         ]: | ||||||
|  |             self.grant_type = GrantTypes.HYBRID | ||||||
|  |  | ||||||
|  |         # Grant type validation. | ||||||
|  |         if not self.grant_type: | ||||||
|  |             LOGGER.warning("Invalid response type", type=self.response_type) | ||||||
|  |             raise AuthorizeError(self.redirect_uri, "unsupported_response_type", "", self.state) | ||||||
|  |  | ||||||
|  |         if self.response_mode not in ResponseMode.values: | ||||||
|  |             self.response_mode = ResponseMode.QUERY | ||||||
|  |  | ||||||
|  |             if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]: | ||||||
|  |                 self.response_mode = ResponseMode.FRAGMENT | ||||||
|  |  | ||||||
|     def check_redirect_uri(self): |     def check_redirect_uri(self): | ||||||
|         """Redirect URI validation.""" |         """Redirect URI validation.""" | ||||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() |         allowed_redirect_urls = self.provider.redirect_uris.split() | ||||||
| @ -257,9 +260,9 @@ class OAuthAuthorizationParams: | |||||||
|         if SCOPE_OFFLINE_ACCESS in self.scope: |         if SCOPE_OFFLINE_ACCESS in self.scope: | ||||||
|             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess |             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess | ||||||
|             if PROMPT_CONSENT not in self.prompt: |             if PROMPT_CONSENT not in self.prompt: | ||||||
|                 raise AuthorizeError( |                 # Instead of ignoring the `offline_access` scope when `prompt` | ||||||
|                     self.redirect_uri, "consent_required", self.grant_type, self.state |                 # isn't set to `consent`, we set override it ourselves | ||||||
|                 ) |                 self.prompt.add(PROMPT_CONSENT) | ||||||
|             if self.response_type not in [ |             if self.response_type not in [ | ||||||
|                 ResponseTypes.CODE, |                 ResponseTypes.CODE, | ||||||
|                 ResponseTypes.CODE_TOKEN, |                 ResponseTypes.CODE_TOKEN, | ||||||
|  | |||||||
| @ -12,10 +12,11 @@ from django.views.decorators.csrf import csrf_exempt | |||||||
| from rest_framework.throttling import AnonRateThrottle | from rest_framework.throttling import AnonRateThrottle | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.models import Application | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -38,7 +39,9 @@ class DeviceView(View): | |||||||
|         ).first() |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         if not get_application(provider): |         try: | ||||||
|  |             _ = provider.application | ||||||
|  |         except Application.DoesNotExist: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         self.provider = provider |         self.provider = provider | ||||||
|         self.client_id = client_id |         self.client_id = client_id | ||||||
|  | |||||||
| @ -1,11 +1,10 @@ | |||||||
| """Device flow views""" | """Device flow views""" | ||||||
|  |  | ||||||
| from typing import Optional | from typing import Any, Optional | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from django.views import View | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.exceptions import ErrorDetail |  | ||||||
| from rest_framework.fields import CharField, IntegerField | from rest_framework.fields import CharField, IntegerField | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| @ -18,6 +17,7 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, | |||||||
| from authentik.flows.stage import ChallengeStageView | from authentik.flows.stage import ChallengeStageView | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
| from authentik.lib.utils.urls import redirect_with_qs | from authentik.lib.utils.urls import redirect_with_qs | ||||||
|  | from authentik.policies.views import PolicyAccessView | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.views.device_finish import ( | from authentik.providers.oauth2.views.device_finish import ( | ||||||
|     PLAN_CONTEXT_DEVICE, |     PLAN_CONTEXT_DEVICE, | ||||||
| @ -44,48 +44,52 @@ def get_application(provider: OAuth2Provider) -> Optional[Application]: | |||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]: | class CodeValidatorView(PolicyAccessView): | ||||||
|     """Validate user token""" |     """Helper to validate frontside token""" | ||||||
|     token = DeviceToken.objects.filter( |  | ||||||
|         user_code=code, |  | ||||||
|     ).first() |  | ||||||
|     if not token: |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     app = get_application(token.provider) |     def __init__(self, code: str, **kwargs: Any) -> None: | ||||||
|     if not app: |         super().__init__(**kwargs) | ||||||
|         return None |         self.code = code | ||||||
|  |  | ||||||
|     scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider) |     def resolve_provider_application(self): | ||||||
|     planner = FlowPlanner(token.provider.authorization_flow) |         self.token = DeviceToken.objects.filter(user_code=self.code).first() | ||||||
|     planner.allow_empty_flows = True |         if not self.token: | ||||||
|     try: |             raise Application.DoesNotExist | ||||||
|         plan = planner.plan( |         self.provider = self.token.provider | ||||||
|             request, |         self.application = self.token.provider.application | ||||||
|             { |  | ||||||
|                 PLAN_CONTEXT_SSO: True, |     def get(self, request: HttpRequest, *args, **kwargs): | ||||||
|                 PLAN_CONTEXT_APPLICATION: app, |         scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider) | ||||||
|                 # OAuth2 related params |         planner = FlowPlanner(self.provider.authorization_flow) | ||||||
|                 PLAN_CONTEXT_DEVICE: token, |         planner.allow_empty_flows = True | ||||||
|                 # Consent related params |         planner.use_cache = False | ||||||
|                 PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") |         try: | ||||||
|                 % {"application": app.name}, |             plan = planner.plan( | ||||||
|                 PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, |                 request, | ||||||
|             }, |                 { | ||||||
|  |                     PLAN_CONTEXT_SSO: True, | ||||||
|  |                     PLAN_CONTEXT_APPLICATION: self.application, | ||||||
|  |                     # OAuth2 related params | ||||||
|  |                     PLAN_CONTEXT_DEVICE: self.token, | ||||||
|  |                     # Consent related params | ||||||
|  |                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") | ||||||
|  |                     % {"application": self.application.name}, | ||||||
|  |                     PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, | ||||||
|  |                 }, | ||||||
|  |             ) | ||||||
|  |         except FlowNonApplicableException: | ||||||
|  |             LOGGER.warning("Flow not applicable to user") | ||||||
|  |             return None | ||||||
|  |         plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage)) | ||||||
|  |         request.session[SESSION_KEY_PLAN] = plan | ||||||
|  |         return redirect_with_qs( | ||||||
|  |             "authentik_core:if-flow", | ||||||
|  |             request.GET, | ||||||
|  |             flow_slug=self.token.provider.authorization_flow.slug, | ||||||
|         ) |         ) | ||||||
|     except FlowNonApplicableException: |  | ||||||
|         LOGGER.warning("Flow not applicable to user") |  | ||||||
|         return None |  | ||||||
|     plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage)) |  | ||||||
|     request.session[SESSION_KEY_PLAN] = plan |  | ||||||
|     return redirect_with_qs( |  | ||||||
|         "authentik_core:if-flow", |  | ||||||
|         request.GET, |  | ||||||
|         flow_slug=token.provider.authorization_flow.slug, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceEntryView(View): | class DeviceEntryView(PolicyAccessView): | ||||||
|     """View used to initiate the device-code flow, url entered by endusers""" |     """View used to initiate the device-code flow, url entered by endusers""" | ||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest) -> HttpResponse: |     def dispatch(self, request: HttpRequest) -> HttpResponse: | ||||||
| @ -95,7 +99,9 @@ class DeviceEntryView(View): | |||||||
|             LOGGER.info("Brand has no device code flow configured", brand=brand) |             LOGGER.info("Brand has no device code flow configured", brand=brand) | ||||||
|             return HttpResponse(status=404) |             return HttpResponse(status=404) | ||||||
|         if QS_KEY_CODE in request.GET: |         if QS_KEY_CODE in request.GET: | ||||||
|             validation = validate_code(request.GET[QS_KEY_CODE], request) |             validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch( | ||||||
|  |                 request | ||||||
|  |             ) | ||||||
|             if validation: |             if validation: | ||||||
|                 return validation |                 return validation | ||||||
|             LOGGER.info("Got code from query parameter but no matching token found") |             LOGGER.info("Got code from query parameter but no matching token found") | ||||||
| @ -130,6 +136,13 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse): | |||||||
|     code = IntegerField() |     code = IntegerField() | ||||||
|     component = CharField(default="ak-provider-oauth2-device-code") |     component = CharField(default="ak-provider-oauth2-device-code") | ||||||
|  |  | ||||||
|  |     def validate_code(self, code: int) -> HttpResponse | None: | ||||||
|  |         """Validate code and save the returned http response""" | ||||||
|  |         response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request) | ||||||
|  |         if not response: | ||||||
|  |             raise ValidationError(_("Invalid code"), "invalid") | ||||||
|  |         return response | ||||||
|  |  | ||||||
|  |  | ||||||
| class OAuthDeviceCodeStage(ChallengeStageView): | class OAuthDeviceCodeStage(ChallengeStageView): | ||||||
|     """Flow challenge for users to enter device codes""" |     """Flow challenge for users to enter device codes""" | ||||||
| @ -145,12 +158,4 @@ class OAuthDeviceCodeStage(ChallengeStageView): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: |     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||||
|         code = response.validated_data["code"] |         return response.validated_data["code"] | ||||||
|         validation = validate_code(code, self.request) |  | ||||||
|         if not validation: |  | ||||||
|             response._errors.setdefault("code", []) |  | ||||||
|             response._errors["code"].append(ErrorDetail(_("Invalid code"), "invalid")) |  | ||||||
|             return self.challenge_invalid(response) |  | ||||||
|         # Run cancel to cleanup the current flow |  | ||||||
|         self.executor.cancel() |  | ||||||
|         return validation |  | ||||||
|  | |||||||
| @ -101,8 +101,8 @@ class UserInfoView(View): | |||||||
|                     value=value, |                     value=value, | ||||||
|                 ) |                 ) | ||||||
|                 continue |                 continue | ||||||
|             LOGGER.debug("updated scope", scope=scope) |  | ||||||
|             always_merger.merge(final_claims, value) |             always_merger.merge(final_claims, value) | ||||||
|  |             LOGGER.debug("updated scope", scope=scope) | ||||||
|         return final_claims |         return final_claims | ||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: |     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||||
| @ -121,8 +121,9 @@ class UserInfoView(View): | |||||||
|         """Handle GET Requests for UserInfo""" |         """Handle GET Requests for UserInfo""" | ||||||
|         if not self.token: |         if not self.token: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         claims = self.get_claims(self.token.provider, self.token) |         claims = {} | ||||||
|         claims["sub"] = self.token.id_token.sub |         claims.setdefault("sub", self.token.id_token.sub) | ||||||
|  |         claims.update(self.get_claims(self.token.provider, self.token)) | ||||||
|         if self.token.id_token.nonce: |         if self.token.id_token.nonce: | ||||||
|             claims["nonce"] = self.token.id_token.nonce |             claims["nonce"] = self.token.id_token.nonce | ||||||
|         response = TokenResponse(claims) |         response = TokenResponse(claims) | ||||||
|  | |||||||
| @ -10,7 +10,3 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | |||||||
|     label = "authentik_providers_proxy" |     label = "authentik_providers_proxy" | ||||||
|     verbose_name = "authentik Providers.Proxy" |     verbose_name = "authentik Providers.Proxy" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_providers_proxy_signals(self): |  | ||||||
|         """Load proxy signals""" |  | ||||||
|         self.import_module("authentik.providers.proxy.signals") |  | ||||||
|  | |||||||
| @ -22,7 +22,6 @@ from rest_framework.serializers import PrimaryKeyRelatedField, ValidationError | |||||||
| from rest_framework.viewsets import ModelViewSet | from rest_framework.viewsets import ModelViewSet | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.providers import ProviderSerializer | from authentik.core.api.providers import ProviderSerializer | ||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||||
| @ -33,6 +32,7 @@ from authentik.providers.saml.processors.assertion import AssertionProcessor | |||||||
| from authentik.providers.saml.processors.authn_request_parser import AuthNRequest | from authentik.providers.saml.processors.authn_request_parser import AuthNRequest | ||||||
| from authentik.providers.saml.processors.metadata import MetadataProcessor | from authentik.providers.saml.processors.metadata import MetadataProcessor | ||||||
| from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT | from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  | |||||||
| @ -10,7 +10,3 @@ class AuthentikProviderSCIMConfig(ManagedAppConfig): | |||||||
|     label = "authentik_providers_scim" |     label = "authentik_providers_scim" | ||||||
|     verbose_name = "authentik Providers.SCIM" |     verbose_name = "authentik Providers.SCIM" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_signals(self): |  | ||||||
|         """Load signals""" |  | ||||||
|         self.import_module("authentik.providers.scim.signals") |  | ||||||
|  | |||||||
| @ -15,10 +15,10 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import GenericViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.policies.event_matcher.models import model_choices | from authentik.policies.event_matcher.models import model_choices | ||||||
| from authentik.rbac.api.rbac import PermissionAssignSerializer | from authentik.rbac.api.rbac import PermissionAssignSerializer | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
| from authentik.rbac.models import Role | from authentik.rbac.models import Role | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -16,11 +16,11 @@ from rest_framework.response import Response | |||||||
| from rest_framework.serializers import ModelSerializer | from rest_framework.serializers import ModelSerializer | ||||||
| from rest_framework.viewsets import GenericViewSet | from rest_framework.viewsets import GenericViewSet | ||||||
|  |  | ||||||
| from authentik.api.decorators import permission_required |  | ||||||
| from authentik.core.api.groups import GroupMemberSerializer | from authentik.core.api.groups import GroupMemberSerializer | ||||||
| from authentik.core.models import User, UserTypes | from authentik.core.models import User, UserTypes | ||||||
| from authentik.policies.event_matcher.models import model_choices | from authentik.policies.event_matcher.models import model_choices | ||||||
| from authentik.rbac.api.rbac import PermissionAssignSerializer | from authentik.rbac.api.rbac import PermissionAssignSerializer | ||||||
|  | from authentik.rbac.decorators import permission_required | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserObjectPermissionSerializer(ModelSerializer): | class UserObjectPermissionSerializer(ModelSerializer): | ||||||
|  | |||||||
| @ -10,7 +10,3 @@ class AuthentikRBACConfig(ManagedAppConfig): | |||||||
|     label = "authentik_rbac" |     label = "authentik_rbac" | ||||||
|     verbose_name = "authentik RBAC" |     verbose_name = "authentik RBAC" | ||||||
|     default = True |     default = True | ||||||
|  |  | ||||||
|     def reconcile_global_load_rbac_signals(self): |  | ||||||
|         """Load rbac signals""" |  | ||||||
|         self.import_module("authentik.rbac.signals") |  | ||||||
|  | |||||||
| @ -14,18 +14,23 @@ LOGGER = get_logger() | |||||||
| def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[list[str]] = None): | def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[list[str]] = None): | ||||||
|     """Check permissions for a single custom action""" |     """Check permissions for a single custom action""" | ||||||
| 
 | 
 | ||||||
|     def wrapper_outter(func: Callable): |     def _check_obj_perm(self: ModelViewSet, request: Request): | ||||||
|  |         # Check obj_perm both globally and on the specific object | ||||||
|  |         # Having the global permission has higher priority | ||||||
|  |         if request.user.has_perm(obj_perm): | ||||||
|  |             return | ||||||
|  |         obj = self.get_object() | ||||||
|  |         if not request.user.has_perm(obj_perm, obj): | ||||||
|  |             LOGGER.debug("denying access for object", user=request.user, perm=obj_perm, obj=obj) | ||||||
|  |             self.permission_denied(request) | ||||||
|  | 
 | ||||||
|  |     def wrapper_outer(func: Callable): | ||||||
|         """Check permissions for a single custom action""" |         """Check permissions for a single custom action""" | ||||||
| 
 | 
 | ||||||
|         @wraps(func) |         @wraps(func) | ||||||
|         def wrapper(self: ModelViewSet, request: Request, *args, **kwargs) -> Response: |         def wrapper(self: ModelViewSet, request: Request, *args, **kwargs) -> Response: | ||||||
|             if obj_perm: |             if obj_perm: | ||||||
|                 obj = self.get_object() |                 _check_obj_perm(self, request) | ||||||
|                 if not request.user.has_perm(obj_perm, obj): |  | ||||||
|                     LOGGER.debug( |  | ||||||
|                         "denying access for object", user=request.user, perm=obj_perm, obj=obj |  | ||||||
|                     ) |  | ||||||
|                     return self.permission_denied(request) |  | ||||||
|             if global_perms: |             if global_perms: | ||||||
|                 for other_perm in global_perms: |                 for other_perm in global_perms: | ||||||
|                     if not request.user.has_perm(other_perm): |                     if not request.user.has_perm(other_perm): | ||||||
| @ -35,4 +40,4 @@ def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[l | |||||||
| 
 | 
 | ||||||
|         return wrapper |         return wrapper | ||||||
| 
 | 
 | ||||||
|     return wrapper_outter |     return wrapper_outer | ||||||
							
								
								
									
										58
									
								
								authentik/rbac/tests/test_decorators.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								authentik/rbac/tests/test_decorators.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | |||||||
|  | """test decorators api""" | ||||||
|  |  | ||||||
|  | from django.urls import reverse | ||||||
|  | from guardian.shortcuts import assign_perm | ||||||
|  | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
|  | from authentik.core.models import Application | ||||||
|  | from authentik.core.tests.utils import create_test_user | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestAPIDecorators(APITestCase): | ||||||
|  |     """test decorators api""" | ||||||
|  |  | ||||||
|  |     def setUp(self) -> None: | ||||||
|  |         super().setUp() | ||||||
|  |         self.user = create_test_user() | ||||||
|  |  | ||||||
|  |     def test_obj_perm_denied(self): | ||||||
|  |         """Test object perm denied""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 403) | ||||||
|  |  | ||||||
|  |     def test_obj_perm_global(self): | ||||||
|  |         """Test object perm successful (global)""" | ||||||
|  |         assign_perm("authentik_core.view_application", self.user) | ||||||
|  |         assign_perm("authentik_events.view_event", self.user) | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|  |     def test_obj_perm_scoped(self): | ||||||
|  |         """Test object perm successful (scoped)""" | ||||||
|  |         assign_perm("authentik_events.view_event", self.user) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||||
|  |         assign_perm("authentik_core.view_application", self.user, app) | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|  |     def test_other_perm_denied(self): | ||||||
|  |         """Test other perm denied""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||||
|  |         assign_perm("authentik_core.view_application", self.user, app) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 403) | ||||||
| @ -91,13 +91,10 @@ def _get_startup_tasks_default_tenant() -> list[Callable]: | |||||||
| def _get_startup_tasks_all_tenants() -> list[Callable]: | def _get_startup_tasks_all_tenants() -> list[Callable]: | ||||||
|     """Get all tasks to be run on startup for all tenants""" |     """Get all tasks to be run on startup for all tenants""" | ||||||
|     from authentik.admin.tasks import clear_update_notifications |     from authentik.admin.tasks import clear_update_notifications | ||||||
|     from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all |  | ||||||
|     from authentik.providers.proxy.tasks import proxy_set_defaults |     from authentik.providers.proxy.tasks import proxy_set_defaults | ||||||
|  |  | ||||||
|     return [ |     return [ | ||||||
|         clear_update_notifications, |         clear_update_notifications, | ||||||
|         outpost_connection_discovery, |  | ||||||
|         outpost_controller_all, |  | ||||||
|         proxy_set_defaults, |         proxy_set_defaults, | ||||||
|     ] |     ] | ||||||
|  |  | ||||||
|  | |||||||
| @ -7,6 +7,8 @@ from psycopg import connect | |||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
|  | QUERY = """SELECT id FROM public.authentik_install_id ORDER BY id LIMIT 1;""" | ||||||
|  |  | ||||||
|  |  | ||||||
| @lru_cache | @lru_cache | ||||||
| def get_install_id() -> str: | def get_install_id() -> str: | ||||||
| @ -18,7 +20,7 @@ def get_install_id() -> str: | |||||||
|     if settings.TEST: |     if settings.TEST: | ||||||
|         return str(uuid4()) |         return str(uuid4()) | ||||||
|     with connection.cursor() as cursor: |     with connection.cursor() as cursor: | ||||||
|         cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;") |         cursor.execute(QUERY) | ||||||
|         return cursor.fetchone()[0] |         return cursor.fetchone()[0] | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -38,5 +40,5 @@ def get_install_id_raw(): | |||||||
|         sslkey=CONFIG.get("postgresql.sslkey"), |         sslkey=CONFIG.get("postgresql.sslkey"), | ||||||
|     ) |     ) | ||||||
|     cursor = conn.cursor() |     cursor = conn.cursor() | ||||||
|     cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;") |     cursor.execute(QUERY) | ||||||
|     return cursor.fetchone()[0] |     return cursor.fetchone()[0] | ||||||
|  | |||||||
| @ -1,8 +1,7 @@ | |||||||
| """Dynamically set SameSite depending if the upstream connection is TLS or not""" | """Dynamically set SameSite depending if the upstream connection is TLS or not""" | ||||||
|  |  | ||||||
| from hashlib import sha512 | from hashlib import sha512 | ||||||
| from time import time | from time import perf_counter, time | ||||||
| from timeit import default_timer |  | ||||||
| from typing import Any, Callable, Optional | from typing import Any, Callable, Optional | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| @ -294,14 +293,14 @@ class LoggingMiddleware: | |||||||
|         self.get_response = get_response |         self.get_response = get_response | ||||||
|  |  | ||||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: |     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||||
|         start = default_timer() |         start = perf_counter() | ||||||
|         response = self.get_response(request) |         response = self.get_response(request) | ||||||
|         status_code = response.status_code |         status_code = response.status_code | ||||||
|         kwargs = { |         kwargs = { | ||||||
|             "request_id": getattr(request, "request_id", None), |             "request_id": getattr(request, "request_id", None), | ||||||
|         } |         } | ||||||
|         kwargs.update(getattr(response, "ak_context", {})) |         kwargs.update(getattr(response, "ak_context", {})) | ||||||
|         self.log(request, status_code, int((default_timer() - start) * 1000), **kwargs) |         self.log(request, status_code, int((perf_counter() - start) * 1000), **kwargs) | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs): |     def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs): | ||||||
|  | |||||||
| @ -69,7 +69,6 @@ TENANT_APPS = [ | |||||||
|     "authentik.admin", |     "authentik.admin", | ||||||
|     "authentik.api", |     "authentik.api", | ||||||
|     "authentik.crypto", |     "authentik.crypto", | ||||||
|     "authentik.events", |  | ||||||
|     "authentik.flows", |     "authentik.flows", | ||||||
|     "authentik.outposts", |     "authentik.outposts", | ||||||
|     "authentik.policies.dummy", |     "authentik.policies.dummy", | ||||||
| @ -482,13 +481,6 @@ def _update_settings(app_path: str): | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|  |  | ||||||
| # Load subapps's settings |  | ||||||
| for _app in set(SHARED_APPS + TENANT_APPS): |  | ||||||
|     if not _app.startswith("authentik"): |  | ||||||
|         continue |  | ||||||
|     _update_settings(f"{_app}.settings") |  | ||||||
| _update_settings("data.user_settings") |  | ||||||
|  |  | ||||||
| if DEBUG: | if DEBUG: | ||||||
|     CELERY["task_always_eager"] = True |     CELERY["task_always_eager"] = True | ||||||
|     os.environ[ENV_GIT_HASH_KEY] = "dev" |     os.environ[ENV_GIT_HASH_KEY] = "dev" | ||||||
| @ -509,5 +501,17 @@ try: | |||||||
| except ImportError: | except ImportError: | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  | # Import events after other apps since it relies on tasks and other things from all apps | ||||||
|  | # being imported for @prefill_task | ||||||
|  | TENANT_APPS.append("authentik.events") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Load subapps's settings | ||||||
|  | for _app in set(SHARED_APPS + TENANT_APPS): | ||||||
|  |     if not _app.startswith("authentik"): | ||||||
|  |         continue | ||||||
|  |     _update_settings(f"{_app}.settings") | ||||||
|  | _update_settings("data.user_settings") | ||||||
|  |  | ||||||
| SHARED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | SHARED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | ||||||
| INSTALLED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | INSTALLED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	