Compare commits
	
		
			153 Commits
		
	
	
		
			smusali/ev
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 1c03cfa906 | |||
| e2dbab5bca | |||
| 3a6c42fefb | |||
| 6bb180f94e | |||
| 03dea17519 | |||
| 49d83f11bd | |||
| 5f0af81e4d | |||
| 63591e1710 | |||
| 6503a7b048 | |||
| 7e244e0679 | |||
| c1998bf3f2 | |||
| 83372618a8 | |||
| 89a876e141 | |||
| 26d6e8bc5c | |||
| d9dc373170 | |||
| 4ec37c5239 | |||
| a9cfa6fe35 | |||
| 5ac5084149 | |||
| eda38a30b1 | |||
| 9b84bf7174 | |||
| f74549be6d | |||
| 76f4d7fb0a | |||
| d1cf1dd083 | |||
| 2835fbd390 | |||
| 76ad2c8925 | |||
| 2270629fdc | |||
| 43a629efc1 | |||
| 4044e52403 | |||
| aa7c846467 | |||
| 8ab7f4073b | |||
| a05856c2ef | |||
| 9e9154e04a | |||
| 32549066c0 | |||
| 5ed3e879a2 | |||
| 4e4923ad0e | |||
| 0302d147e9 | |||
| 8256f1897d | |||
| 16d321835d | |||
| f34612efe6 | |||
| e82f147130 | |||
| 0ea6ad8eea | |||
| f731443220 | |||
| b70a66cde5 | |||
| b733dbbcb0 | |||
| e34d4c0669 | |||
| 310983a4d0 | |||
| 47b0fc86f7 | |||
| b6e961b1f3 | |||
| 874d7ff320 | |||
| e4a5bc9df6 | |||
| 318e0cf9f8 | |||
| bd0815d894 | |||
| af35ecfe66 | |||
| 0c05cd64bb | |||
| cb80b76490 | |||
| 061d4bc758 | |||
| 8ff27f69e1 | |||
| 045cd98276 | |||
| b520843984 | |||
| 92216e4ea8 | |||
| babaeb2d0c | |||
| 52b8f24b75 | |||
| 464addfc8d | |||
| 8df73c2f6f | |||
| 9ab3971e63 | |||
| 09888cb89f | |||
| 2abcc9ce8f | |||
| 5b0e92f034 | |||
| a3bfb3d25c | |||
| 2c1df6702c | |||
| b999e23d27 | |||
| e0db9f3ea1 | |||
| dcc3ca664a | |||
| 7d37e3f668 | |||
| e48f6bbec4 | |||
| d27caaabc3 | |||
| 0dee706a87 | |||
| 7d527beea8 | |||
| 4733778460 | |||
| c048f4a356 | |||
| 65e245c003 | |||
| 600d59ff58 | |||
| 703628f354 | |||
| 693de081ef | |||
| f367249bab | |||
| 2841db082c | |||
| ce24f974aa | |||
| 1f93e6fd3f | |||
| 7dfde9029f | |||
| f5d62b828b | |||
| 703eb682b7 | |||
| 5cae3192b1 | |||
| 83e143032d | |||
| e0e7cc24da | |||
| 8bc746d577 | |||
| a84f403e79 | |||
| e4f4482d2a | |||
| 844b4e96cd | |||
| f3b4e03243 | |||
| 4f5e2a438e | |||
| 32c980e29e | |||
| bd29392825 | |||
| 9756432876 | |||
| 8b2d1a9b21 | |||
| adbd97323c | |||
| 77a8b2d751 | |||
| 08c850938b | |||
| 7db598c04e | |||
| 1ef224f5fd | |||
| b01c48698d | |||
| 1546fa276a | |||
| f50bd74b46 | |||
| 414a5c36c8 | |||
| c4455b6915 | |||
| 9013caeab4 | |||
| 40a1e5a9b2 | |||
| 4dadcc1dfd | |||
| 0b8678f7ee | |||
| aa8dc94a97 | |||
| 20996e994e | |||
| db17f04830 | |||
| b99fca62d8 | |||
| 8818ce3306 | |||
| 25d3f2e06e | |||
| 1537682026 | |||
| ebd05be2c4 | |||
| c90792d876 | |||
| b92630804f | |||
| 1afd5ef95a | |||
| e5cc2c6d98 | |||
| 84fdd4d737 | |||
| 5fe2772567 | |||
| d2f9b66424 | |||
| c9b39f2eba | |||
| 2ecc2119fc | |||
| 49b7ebdc53 | |||
| 70f72c524d | |||
| 87e0ac743a | |||
| b5b8b0e9cd | |||
| d10b358767 | |||
| 0887fa8fde | |||
| 799dd48861 | |||
| 919d1f349f | |||
| a36b6e8315 | |||
| 69f9dfc9f6 | |||
| 27efe68f1c | |||
| e9223618ba | |||
| e9672a5285 | |||
| 7d724d9931 | |||
| edcc6b2031 | |||
| a71948c9b7 | |||
| a395e347df | |||
| f4b336a974 | 
| @ -1,12 +1,20 @@ | ||||
| [bumpversion] | ||||
| current_version = 2023.10.7 | ||||
| current_version = 2024.2.4 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | ||||
| serialize = {major}.{minor}.{patch} | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
| serialize =  | ||||
| 	{major}.{minor}.{patch}-{rc_t}{rc_n} | ||||
| 	{major}.{minor}.{patch} | ||||
| message = release: {new_version} | ||||
| tag_name = version/{new_version} | ||||
|  | ||||
| [bumpversion:part:rc_t] | ||||
| values =  | ||||
| 	rc | ||||
| 	final | ||||
| optional_value = final | ||||
|  | ||||
| [bumpversion:file:pyproject.toml] | ||||
|  | ||||
| [bumpversion:file:docker-compose.yml] | ||||
|  | ||||
| @ -9,9 +9,6 @@ inputs: | ||||
| runs: | ||||
|   using: "composite" | ||||
|   steps: | ||||
|     - name: Generate config | ||||
|       id: ev | ||||
|       uses: ./.github/actions/docker-push-variables | ||||
|     - name: Find Comment | ||||
|       uses: peter-evans/find-comment@v2 | ||||
|       id: fc | ||||
|  | ||||
							
								
								
									
										73
									
								
								.github/actions/docker-push-variables/action.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										73
									
								
								.github/actions/docker-push-variables/action.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,64 +1,47 @@ | ||||
| --- | ||||
| name: "Prepare docker environment variables" | ||||
| description: "Prepare docker environment variables" | ||||
|  | ||||
| inputs: | ||||
|   image-name: | ||||
|     required: true | ||||
|     description: "Docker image prefix" | ||||
|   image-arch: | ||||
|     required: false | ||||
|     description: "Docker image arch" | ||||
|  | ||||
| outputs: | ||||
|   shouldBuild: | ||||
|     description: "Whether to build image or not" | ||||
|     value: ${{ steps.ev.outputs.shouldBuild }} | ||||
|   branchName: | ||||
|     description: "Branch name" | ||||
|     value: ${{ steps.ev.outputs.branchName }} | ||||
|   branchNameContainer: | ||||
|     description: "Branch name (for containers)" | ||||
|     value: ${{ steps.ev.outputs.branchNameContainer }} | ||||
|   timestamp: | ||||
|     description: "Timestamp" | ||||
|     value: ${{ steps.ev.outputs.timestamp }} | ||||
|  | ||||
|   sha: | ||||
|     description: "sha" | ||||
|     value: ${{ steps.ev.outputs.sha }} | ||||
|   shortHash: | ||||
|     description: "shortHash" | ||||
|     value: ${{ steps.ev.outputs.shortHash }} | ||||
|  | ||||
|   version: | ||||
|     description: "version" | ||||
|     description: "Version" | ||||
|     value: ${{ steps.ev.outputs.version }} | ||||
|   versionFamily: | ||||
|     description: "versionFamily" | ||||
|     value: ${{ steps.ev.outputs.versionFamily }} | ||||
|   prerelease: | ||||
|     description: "Prerelease" | ||||
|     value: ${{ steps.ev.outputs.prerelease }} | ||||
|  | ||||
|   imageTags: | ||||
|     description: "Docker image tags" | ||||
|     value: ${{ steps.ev.outputs.imageTags }} | ||||
|   imageMainTag: | ||||
|     description: "Docker image main tag" | ||||
|     value: ${{ steps.ev.outputs.imageMainTag }} | ||||
|  | ||||
| runs: | ||||
|   using: "composite" | ||||
|   steps: | ||||
|     - name: Generate config | ||||
|       id: ev | ||||
|       shell: python | ||||
|       shell: bash | ||||
|       env: | ||||
|         IMAGE_NAME: ${{ inputs.image-name }} | ||||
|         IMAGE_ARCH: ${{ inputs.image-arch }} | ||||
|         PR_HEAD_SHA: ${{ github.event.pull_request.head.sha }} | ||||
|       run: | | ||||
|         """Helper script to get the actual branch name, docker safe""" | ||||
|         import configparser | ||||
|         import os | ||||
|         from time import time | ||||
|  | ||||
|         parser = configparser.ConfigParser() | ||||
|         parser.read(".bumpversion.cfg") | ||||
|  | ||||
|         branch_name = os.environ["GITHUB_REF"] | ||||
|         if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||
|             branch_name = os.environ["GITHUB_HEAD_REF"] | ||||
|  | ||||
|         should_build = str(os.environ.get("DOCKER_USERNAME", "") != "").lower() | ||||
|         version = parser.get("bumpversion", "current_version") | ||||
|         version_family = ".".join(version.split(".")[:-1]) | ||||
|         safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | ||||
|  | ||||
|         sha = os.environ["GITHUB_SHA"] if not "${{ github.event.pull_request.head.sha }}" else "${{ github.event.pull_request.head.sha }}" | ||||
|  | ||||
|         with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||
|             print("branchName=%s" % branch_name, file=_output) | ||||
|             print("branchNameContainer=%s" % safe_branch_name, file=_output) | ||||
|             print("timestamp=%s" % int(time()), file=_output) | ||||
|             print("sha=%s" % sha, file=_output) | ||||
|             print("shortHash=%s" % sha[:7], file=_output) | ||||
|             print("shouldBuild=%s" % should_build, file=_output) | ||||
|             print("version=%s" % version, file=_output) | ||||
|             print("versionFamily=%s" % version_family, file=_output) | ||||
|         python3 ${{ github.action_path }}/push_vars.py | ||||
|  | ||||
							
								
								
									
										62
									
								
								.github/actions/docker-push-variables/push_vars.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								.github/actions/docker-push-variables/push_vars.py
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,62 @@ | ||||
| """Helper script to get the actual branch name, docker safe""" | ||||
|  | ||||
| import configparser | ||||
| import os | ||||
| from time import time | ||||
|  | ||||
| parser = configparser.ConfigParser() | ||||
| parser.read(".bumpversion.cfg") | ||||
|  | ||||
| should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | ||||
|  | ||||
| branch_name = os.environ["GITHUB_REF"] | ||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||
| safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | ||||
|  | ||||
| image_names = os.getenv("IMAGE_NAME").split(",") | ||||
| image_arch = os.getenv("IMAGE_ARCH") or None | ||||
|  | ||||
| is_pull_request = bool(os.getenv("PR_HEAD_SHA")) | ||||
| is_release = "dev" not in image_names[0] | ||||
|  | ||||
| sha = os.environ["GITHUB_SHA"] if not is_pull_request else os.getenv("PR_HEAD_SHA") | ||||
|  | ||||
| # 2042.1.0 or 2042.1.0-rc1 | ||||
| version = parser.get("bumpversion", "current_version") | ||||
| # 2042.1 | ||||
| version_family = ".".join(version.split("-", 1)[0].split(".")[:-1]) | ||||
| prerelease = "-" in version | ||||
|  | ||||
| image_tags = [] | ||||
| if is_release: | ||||
|     for name in image_names: | ||||
|         image_tags += [ | ||||
|             f"{name}:{version}", | ||||
|         ] | ||||
|         if not prerelease: | ||||
|             image_tags += [ | ||||
|                 f"{name}:latest", | ||||
|                 f"{name}:{version_family}", | ||||
|             ] | ||||
| else: | ||||
|     suffix = "" | ||||
|     if image_arch and image_arch != "amd64": | ||||
|         suffix = f"-{image_arch}" | ||||
|     for name in image_names: | ||||
|         image_tags += [ | ||||
|             f"{name}:gh-{sha}{suffix}",  # Used for ArgoCD and PR comments | ||||
|             f"{name}:gh-{safe_branch_name}{suffix}",  # For convenience | ||||
|             f"{name}:gh-{safe_branch_name}-{int(time())}-{sha[:7]}{suffix}",  # Use by FluxCD | ||||
|         ] | ||||
|  | ||||
| image_main_tag = image_tags[0] | ||||
| image_tags_rendered = ",".join(image_tags) | ||||
|  | ||||
| with open(os.environ["GITHUB_OUTPUT"], "a+", encoding="utf-8") as _output: | ||||
|     print("shouldBuild=%s" % should_build, file=_output) | ||||
|     print("sha=%s" % sha, file=_output) | ||||
|     print("version=%s" % version, file=_output) | ||||
|     print("prerelease=%s" % prerelease, file=_output) | ||||
|     print("imageTags=%s" % image_tags_rendered, file=_output) | ||||
|     print("imageMainTag=%s" % image_main_tag, file=_output) | ||||
							
								
								
									
										7
									
								
								.github/actions/docker-push-variables/test.sh
									
									
									
									
										vendored
									
									
										Executable file
									
								
							
							
						
						
									
										7
									
								
								.github/actions/docker-push-variables/test.sh
									
									
									
									
										vendored
									
									
										Executable file
									
								
							| @ -0,0 +1,7 @@ | ||||
| #!/bin/bash -x | ||||
| SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) | ||||
| GITHUB_OUTPUT=/dev/stdout \ | ||||
|     GITHUB_REF=ref \ | ||||
|     GITHUB_SHA=sha \ | ||||
|     IMAGE_NAME=ghcr.io/goauthentik/server,beryju/authentik \ | ||||
|     python $SCRIPT_DIR/push_vars.py | ||||
							
								
								
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/codespell-words.txt
									
									
									
									
										vendored
									
									
								
							| @ -3,3 +3,4 @@ keypairs | ||||
| hass | ||||
| warmup | ||||
| ontext | ||||
| singed | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/pull_request_template.md
									
									
									
									
										vendored
									
									
								
							| @ -27,7 +27,6 @@ If an API change has been made | ||||
| If changes to the frontend have been made | ||||
|  | ||||
| -   [ ] The code has been formatted (`make web`) | ||||
| -   [ ] The translation files have been updated (`make i18n-extract`) | ||||
|  | ||||
| If applicable | ||||
|  | ||||
|  | ||||
							
								
								
									
										81
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										81
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | ||||
| --- | ||||
| name: authentik-ci-main | ||||
|  | ||||
| on: | ||||
| @ -7,7 +8,7 @@ on: | ||||
|       - next | ||||
|       - version-* | ||||
|     paths-ignore: | ||||
|       - website | ||||
|       - website/** | ||||
|   pull_request: | ||||
|     branches: | ||||
|       - main | ||||
| @ -29,7 +30,7 @@ jobs: | ||||
|           - codespell | ||||
|           - isort | ||||
|           - pending-migrations | ||||
|           - pylint | ||||
|           # - pylint | ||||
|           - pyright | ||||
|           - ruff | ||||
|     runs-on: ubuntu-latest | ||||
| @ -69,7 +70,7 @@ jobs: | ||||
|           cp authentik/lib/default.yml local.env.yml | ||||
|           cp -R .github .. | ||||
|           cp -R scripts .. | ||||
|           git checkout version/$(python -c "from authentik import __version__; print(__version__)") | ||||
|           git checkout $(git tag --sort=version:refname | grep '^version/' | grep -vE -- '-rc[0-9]+$' | tail -n1) | ||||
|           rm -rf .github/ scripts/ | ||||
|           mv ../.github ../scripts . | ||||
|       - name: Setup authentik env (stable) | ||||
| @ -134,7 +135,7 @@ jobs: | ||||
|       - name: Setup authentik env | ||||
|         uses: ./.github/actions/setup | ||||
|       - name: Create k8s Kind Cluster | ||||
|         uses: helm/kind-action@v1.8.0 | ||||
|         uses: helm/kind-action@v1.9.0 | ||||
|       - name: run integration | ||||
|         run: | | ||||
|           poetry run coverage run manage.py test tests/integration | ||||
| @ -206,6 +207,12 @@ jobs: | ||||
|     steps: | ||||
|       - run: echo mark | ||||
|   build: | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         arch: | ||||
|           - amd64 | ||||
|           - arm64 | ||||
|     needs: ci-core-mark | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
| @ -225,9 +232,12 @@ jobs: | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-server | ||||
|           image-arch: ${{ matrix.arch }} | ||||
|       - name: Login to Container Registry | ||||
|         uses: docker/login-action@v3 | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         uses: docker/login-action@v3 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.repository_owner }} | ||||
| @ -241,69 +251,16 @@ jobs: | ||||
|           secrets: | | ||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           tags: | | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }} | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }} | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }} | ||||
|           build-args: | | ||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||
|             VERSION=${{ steps.ev.outputs.version }} | ||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} | ||||
|           cache-from: type=gha | ||||
|           cache-to: type=gha,mode=max | ||||
|   build-arm64: | ||||
|     needs: ci-core-mark | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       # Needed to upload contianer images to ghcr.io | ||||
|       packages: write | ||||
|     timeout-minutes: 120 | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           ref: ${{ github.event.pull_request.head.sha }} | ||||
|       - name: Set up QEMU | ||||
|         uses: docker/setup-qemu-action@v3.0.0 | ||||
|       - name: Set up Docker Buildx | ||||
|         uses: docker/setup-buildx-action@v3 | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|       - name: Login to Container Registry | ||||
|         uses: docker/login-action@v3 | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.repository_owner }} | ||||
|           password: ${{ secrets.GITHUB_TOKEN }} | ||||
|       - name: generate ts client | ||||
|         run: make gen-client-ts | ||||
|       - name: Build Docker Image | ||||
|         uses: docker/build-push-action@v5 | ||||
|         with: | ||||
|           context: . | ||||
|           secrets: | | ||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           tags: | | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-arm64 | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.sha }}-arm64 | ||||
|             ghcr.io/goauthentik/dev-server:gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}-arm64 | ||||
|           build-args: | | ||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||
|             VERSION=${{ steps.ev.outputs.version }} | ||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} | ||||
|           platforms: linux/arm64 | ||||
|           cache-from: type=gha | ||||
|           cache-to: type=gha,mode=max | ||||
|           platforms: linux/${{ matrix.arch }} | ||||
|   pr-comment: | ||||
|     needs: | ||||
|       - build | ||||
|       - build-arm64 | ||||
|     runs-on: ubuntu-latest | ||||
|     if: ${{ github.event_name == 'pull_request' }} | ||||
|     permissions: | ||||
| @ -319,7 +276,9 @@ jobs: | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-server | ||||
|       - name: Comment on PR | ||||
|         uses: ./.github/actions/comment-pr-instructions | ||||
|         with: | ||||
|           tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }} | ||||
|           tag: gh-${{ steps.ev.outputs.imageMainTag }} | ||||
|  | ||||
							
								
								
									
										15
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | ||||
| --- | ||||
| name: authentik-ci-outpost | ||||
|  | ||||
| on: | ||||
| @ -28,7 +29,7 @@ jobs: | ||||
|       - name: Generate API | ||||
|         run: make gen-client-go | ||||
|       - name: golangci-lint | ||||
|         uses: golangci/golangci-lint-action@v3 | ||||
|         uses: golangci/golangci-lint-action@v4 | ||||
|         with: | ||||
|           version: v1.54.2 | ||||
|           args: --timeout 5000s --verbose | ||||
| @ -83,9 +84,11 @@ jobs: | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/dev-${{ matrix.type }} | ||||
|       - name: Login to Container Registry | ||||
|         uses: docker/login-action@v3 | ||||
|         if: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|         uses: docker/login-action@v3 | ||||
|         with: | ||||
|           registry: ghcr.io | ||||
|           username: ${{ github.repository_owner }} | ||||
| @ -95,15 +98,11 @@ jobs: | ||||
|       - name: Build Docker Image | ||||
|         uses: docker/build-push-action@v5 | ||||
|         with: | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           tags: | | ||||
|             ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.branchNameContainer }} | ||||
|             ghcr.io/goauthentik/dev-${{ matrix.type }}:gh-${{ steps.ev.outputs.sha }} | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           file: ${{ matrix.type }}.Dockerfile | ||||
|           push: ${{ steps.ev.outputs.shouldBuild == 'true' }} | ||||
|           build-args: | | ||||
|             GIT_BUILD_HASH=${{ steps.ev.outputs.sha }} | ||||
|             VERSION=${{ steps.ev.outputs.version }} | ||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|           context: . | ||||
|           cache-from: type=gha | ||||
|  | ||||
							
								
								
									
										44
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										44
									
								
								.github/workflows/release-publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | ||||
| --- | ||||
| name: authentik-on-release | ||||
|  | ||||
| on: | ||||
| @ -19,6 +20,10 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/server,beryju/authentik | ||||
|       - name: Docker Login Registry | ||||
|         uses: docker/login-action@v3 | ||||
|         with: | ||||
| @ -38,21 +43,12 @@ jobs: | ||||
|         uses: docker/build-push-action@v5 | ||||
|         with: | ||||
|           context: . | ||||
|           push: ${{ github.event_name == 'release' }} | ||||
|           push: true | ||||
|           secrets: | | ||||
|             GEOIPUPDATE_ACCOUNT_ID=${{ secrets.GEOIPUPDATE_ACCOUNT_ID }} | ||||
|             GEOIPUPDATE_LICENSE_KEY=${{ secrets.GEOIPUPDATE_LICENSE_KEY }} | ||||
|           tags: | | ||||
|             beryju/authentik:${{ steps.ev.outputs.version }}, | ||||
|             beryju/authentik:${{ steps.ev.outputs.versionFamily }}, | ||||
|             beryju/authentik:latest, | ||||
|             ghcr.io/goauthentik/server:${{ steps.ev.outputs.version }}, | ||||
|             ghcr.io/goauthentik/server:${{ steps.ev.outputs.versionFamily }}, | ||||
|             ghcr.io/goauthentik/server:latest | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|           build-args: | | ||||
|             VERSION=${{ steps.ev.outputs.version }} | ||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} | ||||
|   build-outpost: | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
| @ -78,6 +74,10 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/${{ matrix.type }},beryju/authentik-${{ matrix.type }} | ||||
|       - name: make empty clients | ||||
|         run: | | ||||
|           mkdir -p ./gen-ts-api | ||||
| @ -96,20 +96,11 @@ jobs: | ||||
|       - name: Build Docker Image | ||||
|         uses: docker/build-push-action@v5 | ||||
|         with: | ||||
|           push: ${{ github.event_name == 'release' }} | ||||
|           tags: | | ||||
|             beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.version }}, | ||||
|             beryju/authentik-${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }}, | ||||
|             beryju/authentik-${{ matrix.type }}:latest, | ||||
|             ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.version }}, | ||||
|             ghcr.io/goauthentik/${{ matrix.type }}:${{ steps.ev.outputs.versionFamily }}, | ||||
|             ghcr.io/goauthentik/${{ matrix.type }}:latest | ||||
|           push: true | ||||
|           tags: ${{ steps.ev.outputs.imageTags }} | ||||
|           file: ${{ matrix.type }}.Dockerfile | ||||
|           platforms: linux/amd64,linux/arm64 | ||||
|           context: . | ||||
|           build-args: | | ||||
|             VERSION=${{ steps.ev.outputs.version }} | ||||
|             VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} | ||||
|   build-outpost-binary: | ||||
|     timeout-minutes: 120 | ||||
|     runs-on: ubuntu-latest | ||||
| @ -181,15 +172,18 @@ jobs: | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           image-name: ghcr.io/goauthentik/server | ||||
|       - name: Get static files from docker image | ||||
|         run: | | ||||
|           docker pull ghcr.io/goauthentik/server:latest | ||||
|           container=$(docker container create ghcr.io/goauthentik/server:latest) | ||||
|           docker pull ${{ steps.ev.outputs.imageMainTag }} | ||||
|           container=$(docker container create ${{ steps.ev.outputs.imageMainTag }}) | ||||
|           docker cp ${container}:web/ . | ||||
|       - name: Create a Sentry.io release | ||||
|         uses: getsentry/action-release@v1 | ||||
|         continue-on-error: true | ||||
|         if: ${{ github.event_name == 'release' }} | ||||
|         env: | ||||
|           SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }} | ||||
|           SENTRY_ORG: authentik-security-inc | ||||
|  | ||||
							
								
								
									
										17
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/release-tag.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,3 +1,4 @@ | ||||
| --- | ||||
| name: authentik-on-tag | ||||
|  | ||||
| on: | ||||
| @ -28,13 +29,13 @@ jobs: | ||||
|         with: | ||||
|           app_id: ${{ secrets.GH_APP_ID }} | ||||
|           private_key: ${{ secrets.GH_APP_PRIVATE_KEY }} | ||||
|       - name: Extract version number | ||||
|         id: get_version | ||||
|         uses: actions/github-script@v7 | ||||
|       - name: prepare variables | ||||
|         uses: ./.github/actions/docker-push-variables | ||||
|         id: ev | ||||
|         env: | ||||
|           DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} | ||||
|         with: | ||||
|           github-token: ${{ steps.generate_token.outputs.token }} | ||||
|           script: | | ||||
|             return context.payload.ref.replace(/\/refs\/tags\/version\//, ''); | ||||
|           image-name: ghcr.io/goauthentik/server | ||||
|       - name: Create Release | ||||
|         id: create_release | ||||
|         uses: actions/create-release@v1.1.4 | ||||
| @ -42,6 +43,6 @@ jobs: | ||||
|           GITHUB_TOKEN: ${{ steps.generate_token.outputs.token }} | ||||
|         with: | ||||
|           tag_name: ${{ github.ref }} | ||||
|           release_name: Release ${{ steps.get_version.outputs.result }} | ||||
|           release_name: Release ${{ steps.ev.outputs.version }} | ||||
|           draft: true | ||||
|           prerelease: false | ||||
|           prerelease: ${{ steps.ev.outputs.prerelease == 'true' }} | ||||
|  | ||||
| @ -1,9 +1,8 @@ | ||||
| name: authentik-backend-translate-compile | ||||
| --- | ||||
| name: authentik-backend-translate-extract-compile | ||||
| on: | ||||
|   push: | ||||
|     branches: [main] | ||||
|     paths: | ||||
|       - "locale/**" | ||||
|   schedule: | ||||
|     - cron: "0 0 * * *" # every day at midnight | ||||
|   workflow_dispatch: | ||||
| 
 | ||||
| env: | ||||
| @ -25,16 +24,20 @@ jobs: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|       - name: Setup authentik env | ||||
|         uses: ./.github/actions/setup | ||||
|       - name: run extract | ||||
|         run: | | ||||
|           poetry run make i18n-extract | ||||
|       - name: run compile | ||||
|         run: poetry run ak compilemessages | ||||
|         run: | | ||||
|           poetry run ak compilemessages | ||||
|           make web-check-compile | ||||
|       - name: Create Pull Request | ||||
|         uses: peter-evans/create-pull-request@v6 | ||||
|         id: cpr | ||||
|         with: | ||||
|           token: ${{ steps.generate_token.outputs.token }} | ||||
|           branch: compile-backend-translation | ||||
|           commit-message: "core: compile backend translations" | ||||
|           title: "core: compile backend translations" | ||||
|           body: "core: compile backend translations" | ||||
|           branch: extract-compile-backend-translation | ||||
|           commit-message: "core, web: update translations" | ||||
|           title: "core, web: update translations" | ||||
|           body: "core, web: update translations" | ||||
|           delete-branch: true | ||||
|           signoff: true | ||||
| @ -37,7 +37,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api | ||||
| RUN npm run build | ||||
|  | ||||
| # Stage 3: Build go proxy | ||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21.6-bookworm AS go-builder | ||||
| FROM --platform=${BUILDPLATFORM} docker.io/golang:1.22.0-bookworm AS go-builder | ||||
|  | ||||
| ARG TARGETOS | ||||
| ARG TARGETARCH | ||||
| @ -83,7 +83,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | ||||
|     /bin/sh -c "/usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||
|  | ||||
| # Stage 5: Python dependencies | ||||
| FROM docker.io/python:3.12.1-slim-bookworm AS python-deps | ||||
| FROM docker.io/python:3.12.2-slim-bookworm AS python-deps | ||||
|  | ||||
| WORKDIR /ak-root/poetry | ||||
|  | ||||
| @ -103,12 +103,13 @@ RUN --mount=type=bind,target=./pyproject.toml,src=./pyproject.toml \ | ||||
|     --mount=type=cache,target=/root/.cache/pip \ | ||||
|     --mount=type=cache,target=/root/.cache/pypoetry \ | ||||
|     python -m venv /ak-root/venv/ && \ | ||||
|     bash -c "source ${VENV_PATH}/bin/activate && \ | ||||
|         pip3 install --upgrade pip && \ | ||||
|         pip3 install poetry && \ | ||||
|     poetry install --only=main --no-ansi --no-interaction | ||||
|         poetry install --only=main --no-ansi --no-interaction --no-root" | ||||
|  | ||||
| # Stage 6: Run | ||||
| FROM docker.io/python:3.12.1-slim-bookworm AS final-image | ||||
| FROM docker.io/python:3.12.2-slim-bookworm AS final-image | ||||
|  | ||||
| ARG GIT_BUILD_HASH | ||||
| ARG VERSION | ||||
|  | ||||
							
								
								
									
										47
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										47
									
								
								Makefile
									
									
									
									
									
								
							| @ -5,9 +5,12 @@ PWD = $(shell pwd) | ||||
| UID = $(shell id -u) | ||||
| GID = $(shell id -g) | ||||
| NPM_VERSION = $(shell python -m scripts.npm_version) | ||||
| PY_SOURCES = authentik tests scripts lifecycle | ||||
| PY_SOURCES = authentik tests scripts lifecycle .github | ||||
| DOCKER_IMAGE ?= "authentik:test" | ||||
|  | ||||
| GEN_API_TS = "gen-ts-api" | ||||
| GEN_API_GO = "gen-go-api" | ||||
|  | ||||
| pg_user := $(shell python -m authentik.lib.config postgresql.user 2>/dev/null) | ||||
| pg_host := $(shell python -m authentik.lib.config postgresql.host 2>/dev/null) | ||||
| pg_name := $(shell python -m authentik.lib.config postgresql.name 2>/dev/null) | ||||
| @ -76,7 +79,15 @@ migrate: ## Run the Authentik Django server's migrations | ||||
| i18n-extract: core-i18n-extract web-i18n-extract  ## Extract strings that require translation into files to send to a translation service | ||||
|  | ||||
| core-i18n-extract: | ||||
| 	ak makemessages --ignore web --ignore internal --ignore web --ignore web-api --ignore website -l en | ||||
| 	ak makemessages \ | ||||
| 		--add-location file \ | ||||
| 		--no-obsolete \ | ||||
| 		--ignore web \ | ||||
| 		--ignore internal \ | ||||
| 		--ignore ${GEN_API_TS} \ | ||||
| 		--ignore ${GEN_API_GO} \ | ||||
| 		--ignore website \ | ||||
| 		-l en | ||||
|  | ||||
| install: web-install website-install core-install  ## Install all requires dependencies for `web`, `website` and `core` | ||||
|  | ||||
| @ -114,7 +125,7 @@ gen-diff:  ## (Release) generate the changelog diff between the current schema a | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		docker.io/openapitools/openapi-diff:2.1.0-beta.6 \ | ||||
| 		docker.io/openapitools/openapi-diff:2.1.0-beta.8 \ | ||||
| 		--markdown /local/diff.md \ | ||||
| 		/local/old_schema.yml /local/schema.yml | ||||
| 	rm old_schema.yml | ||||
| @ -123,11 +134,11 @@ gen-diff:  ## (Release) generate the changelog diff between the current schema a | ||||
| 	npx prettier --write diff.md | ||||
|  | ||||
| gen-clean-ts:  ## Remove generated API client for Typescript | ||||
| 	rm -rf gen-ts-api/ | ||||
| 	rm -rf web/node_modules/@goauthentik/api/ | ||||
| 	rm -rf ./${GEN_API_TS}/ | ||||
| 	rm -rf ./web/node_modules/@goauthentik/api/ | ||||
|  | ||||
| gen-clean-go:  ## Remove generated API client for Go | ||||
| 	rm -rf gen-go-api/ | ||||
| 	rm -rf ./${GEN_API_GO}/ | ||||
|  | ||||
| gen-clean: gen-clean-ts gen-clean-go  ## Remove generated API clients | ||||
|  | ||||
| @ -138,31 +149,31 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | ||||
| 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | ||||
| 		-i /local/schema.yml \ | ||||
| 		-g typescript-fetch \ | ||||
| 		-o /local/gen-ts-api \ | ||||
| 		-o /local/${GEN_API_TS} \ | ||||
| 		-c /local/scripts/api-ts-config.yaml \ | ||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||
| 		--git-repo-id authentik \ | ||||
| 		--git-user-id goauthentik | ||||
| 	mkdir -p web/node_modules/@goauthentik/api | ||||
| 	cd gen-ts-api && npm i | ||||
| 	\cp -rfv gen-ts-api/* web/node_modules/@goauthentik/api | ||||
| 	cd ./${GEN_API_TS} && npm i | ||||
| 	\cp -rf ./${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||
|  | ||||
| gen-client-go: gen-clean-go  ## Build and install the authentik API for Golang | ||||
| 	mkdir -p ./gen-go-api ./gen-go-api/templates | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./gen-go-api/config.yaml | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./gen-go-api/templates/README.mustache | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./gen-go-api/templates/go.mod.mustache | ||||
| 	cp schema.yml ./gen-go-api/ | ||||
| 	mkdir -p ./${GEN_API_GO} ./${GEN_API_GO}/templates | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O ./${GEN_API_GO}/config.yaml | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O ./${GEN_API_GO}/templates/README.mustache | ||||
| 	wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O ./${GEN_API_GO}/templates/go.mod.mustache | ||||
| 	cp schema.yml ./${GEN_API_GO}/ | ||||
| 	docker run \ | ||||
| 		--rm -v ${PWD}/gen-go-api:/local \ | ||||
| 		--rm -v ${PWD}/${GEN_API_GO}:/local \ | ||||
| 		--user ${UID}:${GID} \ | ||||
| 		docker.io/openapitools/openapi-generator-cli:v6.5.0 generate \ | ||||
| 		-i /local/schema.yml \ | ||||
| 		-g go \ | ||||
| 		-o /local/ \ | ||||
| 		-c /local/config.yaml | ||||
| 	go mod edit -replace goauthentik.io/api/v3=./gen-go-api | ||||
| 	rm -rf ./gen-go-api/config.yaml ./gen-go-api/templates/ | ||||
| 	go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO} | ||||
| 	rm -rf ./${GEN_API_GO}/config.yaml ./${GEN_API_GO}/templates/ | ||||
|  | ||||
| gen-dev-config:  ## Generate a local development config file | ||||
| 	python -m scripts.generate_config | ||||
| @ -176,7 +187,7 @@ gen: gen-build gen-client-ts | ||||
| web-build: web-install  ## Build the Authentik UI | ||||
| 	cd web && npm run build | ||||
|  | ||||
| web: web-lint-fix web-lint web-check-compile web-i18n-extract  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||
| web: web-lint-fix web-lint web-check-compile  ## Automatically fix formatting issues in the Authentik UI source code, lint the code, and compile it | ||||
|  | ||||
| web-install:  ## Install the necessary libraries to build the Authentik UI | ||||
| 	cd web && npm ci | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| from os import environ | ||||
| from typing import Optional | ||||
|  | ||||
| __version__ = "2023.10.7" | ||||
| __version__ = "2024.2.4" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -15,7 +15,3 @@ class AuthentikAdminConfig(ManagedAppConfig): | ||||
|     label = "authentik_admin" | ||||
|     verbose_name = "authentik Admin" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_admin_signals(self): | ||||
|         """Load admin signals""" | ||||
|         self.import_module("authentik.admin.signals") | ||||
|  | ||||
| @ -1,35 +0,0 @@ | ||||
| """test decorators api""" | ||||
|  | ||||
| from django.urls import reverse | ||||
| from guardian.shortcuts import assign_perm | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application, User | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestAPIDecorators(APITestCase): | ||||
|     """test decorators api""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.user = User.objects.create(username="test-user") | ||||
|  | ||||
|     def test_obj_perm_denied(self): | ||||
|         """Test object perm denied""" | ||||
|         self.client.force_login(self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 403) | ||||
|  | ||||
|     def test_other_perm_denied(self): | ||||
|         """Test other perm denied""" | ||||
|         self.client.force_login(self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         assign_perm("authentik_core.view_application", self.user, app) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 403) | ||||
| @ -68,7 +68,11 @@ class ConfigView(APIView): | ||||
|         """Get all capabilities this server instance supports""" | ||||
|         caps = [] | ||||
|         deb_test = settings.DEBUG or settings.TEST | ||||
|         if Path(settings.MEDIA_ROOT).is_mount() or deb_test: | ||||
|         if ( | ||||
|             CONFIG.get("storage.media.backend", "file") == "s3" | ||||
|             or Path(settings.STORAGES["default"]["OPTIONS"]["location"]).is_mount() | ||||
|             or deb_test | ||||
|         ): | ||||
|             caps.append(Capabilities.CAN_SAVE_MEDIA) | ||||
|         for processor in get_context_processors(): | ||||
|             if cap := processor.capability(): | ||||
|  | ||||
| @ -10,13 +10,13 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ListSerializer, ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.models import BlueprintInstance | ||||
| from authentik.blueprints.v1.importer import Importer | ||||
| from authentik.blueprints.v1.oci import OCI_PREFIX | ||||
| from authentik.blueprints.v1.tasks import apply_blueprint, blueprints_find_dict | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class ManagedSerializer: | ||||
|  | ||||
| @ -21,10 +21,27 @@ class ManagedAppConfig(AppConfig): | ||||
|         self.logger = get_logger().bind(app_name=app_name) | ||||
|  | ||||
|     def ready(self) -> None: | ||||
|         self.import_related() | ||||
|         self.reconcile_global() | ||||
|         self.reconcile_tenant() | ||||
|         return super().ready() | ||||
|  | ||||
|     def import_related(self): | ||||
|         """Automatically import related modules which rely on just being imported | ||||
|         to register themselves (mainly django signals and celery tasks)""" | ||||
|  | ||||
|         def import_relative(rel_module: str): | ||||
|             try: | ||||
|                 module_name = f"{self.name}.{rel_module}" | ||||
|                 import_module(module_name) | ||||
|                 self.logger.info("Imported related module", module=module_name) | ||||
|             except ModuleNotFoundError: | ||||
|                 pass | ||||
|  | ||||
|         import_relative("checks") | ||||
|         import_relative("tasks") | ||||
|         import_relative("signals") | ||||
|  | ||||
|     def import_module(self, path: str): | ||||
|         """Load module""" | ||||
|         import_module(path) | ||||
|  | ||||
| @ -74,7 +74,7 @@ class Exporter: | ||||
|  | ||||
|  | ||||
| class FlowExporter(Exporter): | ||||
|     """Exporter customised to only return objects related to `flow`""" | ||||
|     """Exporter customized to only return objects related to `flow`""" | ||||
|  | ||||
|     flow: Flow | ||||
|     with_policies: bool | ||||
|  | ||||
| @ -39,7 +39,8 @@ from authentik.core.models import ( | ||||
|     Source, | ||||
|     UserSourceConnection, | ||||
| ) | ||||
| from authentik.enterprise.models import LicenseKey, LicenseUsage | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.enterprise.models import LicenseUsage | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | ||||
| from authentik.events.models import SystemTask | ||||
| from authentik.events.utils import cleanse_dict | ||||
|  | ||||
| @ -3,6 +3,7 @@ | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from hashlib import sha512 | ||||
| from pathlib import Path | ||||
| from sys import platform | ||||
| from typing import Optional | ||||
|  | ||||
| from dacite.core import from_dict | ||||
| @ -60,7 +61,12 @@ def start_blueprint_watcher(): | ||||
|     if _file_watcher_started: | ||||
|         return | ||||
|     observer = Observer() | ||||
|     observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True) | ||||
|     kwargs = {} | ||||
|     if platform.startswith("linux"): | ||||
|         kwargs["event_filter"] = (FileCreatedEvent, FileModifiedEvent) | ||||
|     observer.schedule( | ||||
|         BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True, **kwargs | ||||
|     ) | ||||
|     observer.start() | ||||
|     _file_watcher_started = True | ||||
|  | ||||
| @ -68,23 +74,33 @@ def start_blueprint_watcher(): | ||||
| class BlueprintEventHandler(FileSystemEventHandler): | ||||
|     """Event handler for blueprint events""" | ||||
|  | ||||
|     def on_any_event(self, event: FileSystemEvent): | ||||
|         if not isinstance(event, (FileCreatedEvent, FileModifiedEvent)): | ||||
|             return | ||||
|     # We only ever get creation and modification events. | ||||
|     # See the creation of the Observer instance above for the event filtering. | ||||
|  | ||||
|     # Even though we filter to only get file events, we might still get | ||||
|     # directory events as some implementations such as inotify do not support | ||||
|     # filtering on file/directory. | ||||
|  | ||||
|     def dispatch(self, event: FileSystemEvent) -> None: | ||||
|         """Call specific event handler method. Ignores directory changes.""" | ||||
|         if event.is_directory: | ||||
|             return | ||||
|             return None | ||||
|         return super().dispatch(event) | ||||
|  | ||||
|     def on_created(self, event: FileSystemEvent): | ||||
|         """Process file creation""" | ||||
|         LOGGER.debug("new blueprint file created, starting discovery") | ||||
|         for tenant in Tenant.objects.filter(ready=True): | ||||
|             with tenant: | ||||
|                 blueprints_discovery.delay() | ||||
|  | ||||
|     def on_modified(self, event: FileSystemEvent): | ||||
|         """Process file modification""" | ||||
|         path = Path(event.src_path) | ||||
|         root = Path(CONFIG.get("blueprints_dir")).absolute() | ||||
|         path = Path(event.src_path).absolute() | ||||
|         rel_path = str(path.relative_to(root)) | ||||
|         for tenant in Tenant.objects.filter(ready=True): | ||||
|             with tenant: | ||||
|                 root = Path(CONFIG.get("blueprints_dir")).absolute() | ||||
|                 path = Path(event.src_path).absolute() | ||||
|                 rel_path = str(path.relative_to(root)) | ||||
|                 if isinstance(event, FileCreatedEvent): | ||||
|                     LOGGER.debug("new blueprint file created, starting discovery", path=rel_path) | ||||
|                     blueprints_discovery.delay(rel_path) | ||||
|                 if isinstance(event, FileModifiedEvent): | ||||
|                 for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): | ||||
|                     LOGGER.debug("modified blueprint file, starting apply", instance=instance) | ||||
|                     apply_blueprint.delay(instance.pk.hex) | ||||
|  | ||||
| @ -9,6 +9,7 @@ from sentry_sdk.hub import Hub | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.tenants.models import Tenant | ||||
|  | ||||
| _q_default = Q(default=True) | ||||
| DEFAULT_BRAND = Brand(domain="fallback") | ||||
| @ -30,13 +31,14 @@ def get_brand_for_request(request: HttpRequest) -> Brand: | ||||
| def context_processor(request: HttpRequest) -> dict[str, Any]: | ||||
|     """Context Processor that injects brand object into every template""" | ||||
|     brand = getattr(request, "brand", DEFAULT_BRAND) | ||||
|     tenant = getattr(request, "tenant", Tenant()) | ||||
|     trace = "" | ||||
|     span = Hub.current.scope.span | ||||
|     if span: | ||||
|         trace = span.to_traceparent() | ||||
|     return { | ||||
|         "brand": brand, | ||||
|         "footer_links": request.tenant.footer_links, | ||||
|         "footer_links": tenant.footer_links, | ||||
|         "sentry_trace": trace, | ||||
|         "version": get_full_version(), | ||||
|     } | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from copy import copy | ||||
| from datetime import timedelta | ||||
| from typing import Optional | ||||
| from typing import Iterator, Optional | ||||
|  | ||||
| from django.core.cache import cache | ||||
| from django.db.models import QuerySet | ||||
| @ -23,7 +23,6 @@ from structlog.stdlib import get_logger | ||||
| from structlog.testing import capture_logs | ||||
|  | ||||
| from authentik.admin.api.metrics import CoordinateSerializer | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.providers import ProviderSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| @ -39,6 +38,7 @@ from authentik.lib.utils.file import ( | ||||
| from authentik.policies.api.exec import PolicyTestResultSerializer | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| from authentik.policies.types import PolicyResult | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.rbac.filters import ObjectFilter | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| @ -131,14 +131,14 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|         return queryset | ||||
|  | ||||
|     def _get_allowed_applications( | ||||
|         self, queryset: QuerySet, user: Optional[User] = None | ||||
|         self, pagined_apps: Iterator[Application], user: Optional[User] = None | ||||
|     ) -> list[Application]: | ||||
|         applications = [] | ||||
|         request = self.request._request | ||||
|         if user: | ||||
|             request = copy(request) | ||||
|             request.user = user | ||||
|         for application in queryset: | ||||
|         for application in pagined_apps: | ||||
|             engine = PolicyEngine(application, request.user, request) | ||||
|             engine.build() | ||||
|             if engine.passing: | ||||
| @ -215,7 +215,7 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|             return super().list(request) | ||||
|  | ||||
|         queryset = self._filter_queryset_for_list(self.get_queryset()) | ||||
|         self.paginate_queryset(queryset) | ||||
|         pagined_apps = self.paginate_queryset(queryset) | ||||
|  | ||||
|         if "for_user" in request.query_params: | ||||
|             try: | ||||
| @ -229,18 +229,18 @@ class ApplicationViewSet(UsedByMixin, ModelViewSet): | ||||
|                     raise ValidationError({"for_user": "User not found"}) | ||||
|             except ValueError as exc: | ||||
|                 raise ValidationError from exc | ||||
|             allowed_applications = self._get_allowed_applications(queryset, user=for_user) | ||||
|             allowed_applications = self._get_allowed_applications(pagined_apps, user=for_user) | ||||
|             serializer = self.get_serializer(allowed_applications, many=True) | ||||
|             return self.get_paginated_response(serializer.data) | ||||
|  | ||||
|         allowed_applications = [] | ||||
|         if not should_cache: | ||||
|             allowed_applications = self._get_allowed_applications(queryset) | ||||
|             allowed_applications = self._get_allowed_applications(pagined_apps) | ||||
|         if should_cache: | ||||
|             allowed_applications = cache.get(user_app_cache_key(self.request.user.pk)) | ||||
|             if not allowed_applications: | ||||
|                 LOGGER.debug("Caching allowed application list") | ||||
|                 allowed_applications = self._get_allowed_applications(queryset) | ||||
|                 allowed_applications = self._get_allowed_applications(pagined_apps) | ||||
|                 cache.set( | ||||
|                     user_app_cache_key(self.request.user.pk), | ||||
|                     allowed_applications, | ||||
|  | ||||
| @ -15,11 +15,11 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ListSerializer, ModelSerializer, ValidationError | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import JSONDictField, PassiveSerializer | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.rbac.api.roles import RoleSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class GroupMemberSerializer(ModelSerializer): | ||||
|  | ||||
| @ -14,7 +14,6 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.api import ManagedSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import MetaNameSerializer, PassiveSerializer, TypeCreateSerializer | ||||
| @ -23,6 +22,7 @@ from authentik.core.models import PropertyMapping | ||||
| from authentik.events.utils import sanitize_item | ||||
| from authentik.lib.utils.reflection import all_subclasses | ||||
| from authentik.policies.api.exec import PolicyTestSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class PropertyMappingTestResultSerializer(PassiveSerializer): | ||||
| @ -118,7 +118,11 @@ class PropertyMappingViewSet( | ||||
|     @action(detail=True, pagination_class=None, filter_backends=[], methods=["POST"]) | ||||
|     def test(self, request: Request, pk: str) -> Response: | ||||
|         """Test Property Mapping""" | ||||
|         mapping: PropertyMapping = self.get_object() | ||||
|         _mapping: PropertyMapping = self.get_object() | ||||
|         # Use `get_subclass` to get correct class and correct `.evaluate` implementation | ||||
|         mapping = PropertyMapping.objects.get_subclass(pk=_mapping.pk) | ||||
|         # FIXME: when we separate policy mappings between ones for sources | ||||
|         # and ones for providers, we need to make the user field optional for the source mapping | ||||
|         test_params = PolicyTestSerializer(data=request.data) | ||||
|         if not test_params.is_valid(): | ||||
|             return Response(test_params.errors, status=400) | ||||
|  | ||||
| @ -16,7 +16,6 @@ from rest_framework.viewsets import GenericViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import MetaNameSerializer, TypeCreateSerializer | ||||
| @ -30,6 +29,7 @@ from authentik.lib.utils.file import ( | ||||
| ) | ||||
| from authentik.lib.utils.reflection import all_subclasses | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @ -15,15 +15,15 @@ from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.authorization import OwnerSuperuserPermissions | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.api import ManagedSerializer | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.users import UserSerializer | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.utils import model_to_dict | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||
| @ -36,6 +36,13 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | ||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||
|             self.fields["key"] = CharField(required=False) | ||||
|  | ||||
|     def validate_user(self, user: User): | ||||
|         """Ensure user of token cannot be changed""" | ||||
|         if self.instance and self.instance.user_id: | ||||
|             if user.pk != self.instance.user_id: | ||||
|                 raise ValidationError("User cannot be changed") | ||||
|         return user | ||||
|  | ||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||
|         """Ensure only API or App password tokens are created.""" | ||||
|         request: Request = self.context.get("request") | ||||
|  | ||||
| @ -49,7 +49,6 @@ from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.admin.api.metrics import CoordinateSerializer | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| @ -74,6 +73,7 @@ from authentik.flows.models import FlowToken | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlanner | ||||
| from authentik.flows.views.executor import QS_KEY_TOKEN | ||||
| from authentik.lib.avatars import get_avatar | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.stages.email.models import EmailStage | ||||
| from authentik.stages.email.tasks import send_mails | ||||
| from authentik.stages.email.utils import TemplateEmailMessage | ||||
| @ -154,7 +154,7 @@ class UserSerializer(ModelSerializer): | ||||
|  | ||||
|     def get_avatar(self, user: User) -> str: | ||||
|         """User's avatar, either a http/https URL or a data URI""" | ||||
|         return get_avatar(user, self.context["request"]) | ||||
|         return get_avatar(user, self.context.get("request")) | ||||
|  | ||||
|     def validate_path(self, path: str) -> str: | ||||
|         """Validate path""" | ||||
| @ -218,7 +218,7 @@ class UserSelfSerializer(ModelSerializer): | ||||
|  | ||||
|     def get_avatar(self, user: User) -> str: | ||||
|         """User's avatar, either a http/https URL or a data URI""" | ||||
|         return get_avatar(user, self.context["request"]) | ||||
|         return get_avatar(user, self.context.get("request")) | ||||
|  | ||||
|     @extend_schema_field( | ||||
|         ListSerializer( | ||||
| @ -533,7 +533,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|             400: OpenApiResponse(description="Bad request"), | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, methods=["POST"]) | ||||
|     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||
|     def set_password(self, request: Request, pk: int) -> Response: | ||||
|         """Set password for user""" | ||||
|         user: User = self.get_object() | ||||
| @ -611,7 +611,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|         email_stage: EmailStage = stages.first() | ||||
|         message = TemplateEmailMessage( | ||||
|             subject=_(email_stage.subject), | ||||
|             to=[for_user.email], | ||||
|             to=[(for_user.name, for_user.email)], | ||||
|             template_name=email_stage.template, | ||||
|             language=for_user.locale(request), | ||||
|             template_context={ | ||||
| @ -631,7 +631,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|             "401": OpenApiResponse(description="Access denied"), | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, methods=["POST"]) | ||||
|     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||
|     def impersonate(self, request: Request, pk: int) -> Response: | ||||
|         """Impersonate a user""" | ||||
|         if not request.tenant.impersonation: | ||||
|  | ||||
| @ -14,10 +14,6 @@ class AuthentikCoreConfig(ManagedAppConfig): | ||||
|     mountpoint = "" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_core_signals(self): | ||||
|         """Load core signals""" | ||||
|         self.import_module("authentik.core.signals") | ||||
|  | ||||
|     def reconcile_global_debug_worker_hook(self): | ||||
|         """Dispatch startup tasks inline when debugging""" | ||||
|         if settings.DEBUG: | ||||
|  | ||||
| @ -43,7 +43,9 @@ class TokenBackend(InbuiltBackend): | ||||
|         self, request: HttpRequest, username: Optional[str], password: Optional[str], **kwargs: Any | ||||
|     ) -> Optional[User]: | ||||
|         try: | ||||
|             # pylint: disable=no-member | ||||
|             user = User._default_manager.get_by_natural_key(username) | ||||
|         # pylint: disable=no-member | ||||
|         except User.DoesNotExist: | ||||
|             # Run the default password hasher once to reduce the timing | ||||
|             # difference between an existing and a nonexistent user (#20760). | ||||
|  | ||||
| @ -37,6 +37,7 @@ def clean_expired_models(self: SystemTask): | ||||
|         messages.append(f"Expired {amount} {cls._meta.verbose_name_plural}") | ||||
|     # Special case | ||||
|     amount = 0 | ||||
|     # pylint: disable=no-member | ||||
|     for session in AuthenticatedSession.objects.all(): | ||||
|         cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||
|         value = None | ||||
| @ -49,6 +50,7 @@ def clean_expired_models(self: SystemTask): | ||||
|             session.delete() | ||||
|             amount += 1 | ||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||
|     # pylint: disable=no-member | ||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||
|     self.set_status(TaskStatus.SUCCESSFUL, *messages) | ||||
|  | ||||
|  | ||||
| @ -7,8 +7,8 @@ from guardian.shortcuts import get_anonymous_user | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.api.tokens import TokenSerializer | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents, User | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.core.models import USER_ATTRIBUTE_TOKEN_EXPIRING, Token, TokenIntents | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| @ -17,7 +17,7 @@ class TestTokenAPI(APITestCase): | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.user = User.objects.create(username="testuser") | ||||
|         self.user = create_test_user() | ||||
|         self.admin = create_test_admin_user() | ||||
|         self.client.force_login(self.user) | ||||
|  | ||||
| @ -76,6 +76,24 @@ class TestTokenAPI(APITestCase): | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, False) | ||||
|  | ||||
|     def test_token_change_user(self): | ||||
|         """Test creating a token and then changing the user""" | ||||
|         ident = generate_id() | ||||
|         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         token = Token.objects.get(identifier=ident) | ||||
|         self.assertEqual(token.user, self.user) | ||||
|         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||
|         self.assertEqual(token.expiring, True) | ||||
|         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) | ||||
|         response = self.client.put( | ||||
|             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), | ||||
|             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         token.refresh_from_db() | ||||
|         self.assertEqual(token.user, self.user) | ||||
|  | ||||
|     def test_list(self): | ||||
|         """Test Token List (Test normal authentication)""" | ||||
|         Token.objects.all().delete() | ||||
|  | ||||
| @ -24,13 +24,13 @@ from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.authorization import SecretKeyFilter | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.crypto.apps import MANAGED_KEY | ||||
| from authentik.crypto.builder import CertificateBuilder | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| """authentik crypto app config""" | ||||
|  | ||||
| from datetime import datetime | ||||
| from datetime import datetime, timezone | ||||
| from typing import Optional | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| @ -17,10 +17,6 @@ class AuthentikCryptoConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Crypto" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_crypto_tasks(self): | ||||
|         """Load crypto tasks""" | ||||
|         self.import_module("authentik.crypto.tasks") | ||||
|  | ||||
|     def _create_update_cert(self): | ||||
|         from authentik.crypto.builder import CertificateBuilder | ||||
|         from authentik.crypto.models import CertificateKeyPair | ||||
| @ -47,9 +43,9 @@ class AuthentikCryptoConfig(ManagedAppConfig): | ||||
|         cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( | ||||
|             managed=MANAGED_KEY | ||||
|         ).first() | ||||
|         now = datetime.now() | ||||
|         now = datetime.now(tz=timezone.utc) | ||||
|         if not cert or ( | ||||
|             now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after | ||||
|             now < cert.certificate.not_valid_after_utc or now > cert.certificate.not_valid_after_utc | ||||
|         ): | ||||
|             self._create_update_cert() | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| """Enterprise API Views""" | ||||
|  | ||||
| from datetime import datetime, timedelta | ||||
| from dataclasses import asdict | ||||
| from datetime import timedelta | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext as _ | ||||
| @ -8,29 +9,29 @@ from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import extend_schema, inline_serializer | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import BooleanField, CharField, DateTimeField, IntegerField | ||||
| from rest_framework.fields import CharField, IntegerField | ||||
| from rest_framework.permissions import IsAuthenticated | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import User, UserTypes | ||||
| from authentik.enterprise.models import License, LicenseKey | ||||
| from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.root.install_id import get_install_id | ||||
|  | ||||
|  | ||||
| class EnterpriseRequiredMixin: | ||||
|     """Mixin to validate that a valid enterprise license | ||||
|     exists before allowing to safe the object""" | ||||
|     exists before allowing to save the object""" | ||||
|  | ||||
|     def validate(self, attrs: dict) -> dict: | ||||
|         """Check that a valid license exists""" | ||||
|         total = LicenseKey.get_total() | ||||
|         if not total.is_valid(): | ||||
|         if not LicenseKey.cached_summary().has_license: | ||||
|             raise ValidationError(_("Enterprise is required to create/update this object.")) | ||||
|         return super().validate(attrs) | ||||
|  | ||||
| @ -61,19 +62,6 @@ class LicenseSerializer(ModelSerializer): | ||||
|         } | ||||
|  | ||||
|  | ||||
| class LicenseSummary(PassiveSerializer): | ||||
|     """Serializer for license status""" | ||||
|  | ||||
|     internal_users = IntegerField(required=True) | ||||
|     external_users = IntegerField(required=True) | ||||
|     valid = BooleanField() | ||||
|     show_admin_warning = BooleanField() | ||||
|     show_user_warning = BooleanField() | ||||
|     read_only = BooleanField() | ||||
|     latest_valid = DateTimeField() | ||||
|     has_license = BooleanField() | ||||
|  | ||||
|  | ||||
| class LicenseForecastSerializer(PassiveSerializer): | ||||
|     """Serializer for license forecast""" | ||||
|  | ||||
| @ -111,31 +99,13 @@ class LicenseViewSet(UsedByMixin, ModelViewSet): | ||||
|     @extend_schema( | ||||
|         request=OpenApiTypes.NONE, | ||||
|         responses={ | ||||
|             200: LicenseSummary(), | ||||
|             200: LicenseSummarySerializer(), | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=False, methods=["GET"], permission_classes=[IsAuthenticated]) | ||||
|     def summary(self, request: Request) -> Response: | ||||
|         """Get the total license status""" | ||||
|         total = LicenseKey.get_total() | ||||
|         last_valid = LicenseKey.last_valid_date() | ||||
|         # TODO: move this to a different place? | ||||
|         show_admin_warning = last_valid < now() - timedelta(weeks=2) | ||||
|         show_user_warning = last_valid < now() - timedelta(weeks=4) | ||||
|         read_only = last_valid < now() - timedelta(weeks=6) | ||||
|         latest_valid = datetime.fromtimestamp(total.exp) | ||||
|         response = LicenseSummary( | ||||
|             data={ | ||||
|                 "internal_users": total.internal_users, | ||||
|                 "external_users": total.external_users, | ||||
|                 "valid": total.is_valid(), | ||||
|                 "show_admin_warning": show_admin_warning, | ||||
|                 "show_user_warning": show_user_warning, | ||||
|                 "read_only": read_only, | ||||
|                 "latest_valid": latest_valid, | ||||
|                 "has_license": License.objects.all().count() > 0, | ||||
|             } | ||||
|         ) | ||||
|         response = LicenseSummarySerializer(data=asdict(LicenseKey.cached_summary())) | ||||
|         response.is_valid(raise_exception=True) | ||||
|         return Response(response.data) | ||||
|  | ||||
|  | ||||
| @ -17,16 +17,12 @@ class AuthentikEnterpriseConfig(EnterpriseConfig): | ||||
|     verbose_name = "authentik Enterprise" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_enterprise_signals(self): | ||||
|         """Load enterprise signals""" | ||||
|         self.import_module("authentik.enterprise.signals") | ||||
|  | ||||
|     def enabled(self): | ||||
|         """Return true if enterprise is enabled and valid""" | ||||
|         return self.check_enabled() or settings.TEST | ||||
|  | ||||
|     def check_enabled(self): | ||||
|         """Actual enterprise check, cached""" | ||||
|         from authentik.enterprise.models import LicenseKey | ||||
|         from authentik.enterprise.license import LicenseKey | ||||
|  | ||||
|         return LicenseKey.get_total().is_valid() | ||||
|         return LicenseKey.cached_summary().valid | ||||
|  | ||||
| @ -11,7 +11,6 @@ from django.db.models.expressions import BaseExpression, Combinable | ||||
| from django.db.models.signals import post_init | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import User | ||||
| from authentik.events.middleware import AuditMiddleware, should_log_model | ||||
| from authentik.events.utils import cleanse_dict, sanitize_item | ||||
|  | ||||
| @ -19,26 +18,19 @@ from authentik.events.utils import cleanse_dict, sanitize_item | ||||
| class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|     """Enterprise audit middleware""" | ||||
|  | ||||
|     _enabled = None | ||||
|  | ||||
|     @property | ||||
|     def enabled(self): | ||||
|         """Lazy check if audit logging is enabled""" | ||||
|         if self._enabled is None: | ||||
|             self._enabled = apps.get_app_config("authentik_enterprise").enabled() | ||||
|         return self._enabled | ||||
|         """Check if audit logging is enabled""" | ||||
|         return apps.get_app_config("authentik_enterprise").enabled() | ||||
|  | ||||
|     def connect(self, request: HttpRequest): | ||||
|         super().connect(request) | ||||
|         if not self.enabled: | ||||
|             return | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_init.connect( | ||||
|             partial(self.post_init_handler, user=user, request=request), | ||||
|             partial(self.post_init_handler, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -80,7 +72,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||
|         return sanitize_item(diff) | ||||
|  | ||||
|     def post_init_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """post_init django model handler""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
| @ -95,7 +87,6 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
| @ -117,6 +108,4 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | ||||
|                 for field_set in ignored_field_sets: | ||||
|                     if set(diff.keys()) == set(field_set): | ||||
|                         return None | ||||
|         return super().post_save_handler( | ||||
|             user, request, sender, instance, created, thread_kwargs, **_ | ||||
|         ) | ||||
|         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) | ||||
|  | ||||
							
								
								
									
										214
									
								
								authentik/enterprise/license.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										214
									
								
								authentik/enterprise/license.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,214 @@ | ||||
| """Enterprise license""" | ||||
|  | ||||
| from base64 import b64decode | ||||
| from binascii import Error | ||||
| from dataclasses import asdict, dataclass, field | ||||
| from datetime import datetime, timedelta | ||||
| from enum import Enum | ||||
| from functools import lru_cache | ||||
| from time import mktime | ||||
|  | ||||
| from cryptography.exceptions import InvalidSignature | ||||
| from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate | ||||
| from dacite import from_dict | ||||
| from django.core.cache import cache | ||||
| from django.db.models.query import QuerySet | ||||
| from django.utils.timezone import now | ||||
| from jwt import PyJWTError, decode, get_unverified_header | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import BooleanField, DateTimeField, IntegerField | ||||
|  | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import User, UserTypes | ||||
| from authentik.enterprise.models import License, LicenseUsage | ||||
| from authentik.root.install_id import get_install_id | ||||
|  | ||||
| CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license" | ||||
| CACHE_EXPIRY_ENTERPRISE_LICENSE = 3 * 60 * 60  # 2 Hours | ||||
|  | ||||
|  | ||||
| @lru_cache() | ||||
| def get_licensing_key() -> Certificate: | ||||
|     """Get Root CA PEM""" | ||||
|     with open("authentik/enterprise/public.pem", "rb") as _key: | ||||
|         return load_pem_x509_certificate(_key.read()) | ||||
|  | ||||
|  | ||||
| def get_license_aud() -> str: | ||||
|     """Get the JWT audience field""" | ||||
|     return f"enterprise.goauthentik.io/license/{get_install_id()}" | ||||
|  | ||||
|  | ||||
| class LicenseFlags(Enum): | ||||
|     """License flags""" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class LicenseSummary: | ||||
|     """Internal representation of a license summary""" | ||||
|  | ||||
|     internal_users: int | ||||
|     external_users: int | ||||
|     valid: bool | ||||
|     show_admin_warning: bool | ||||
|     show_user_warning: bool | ||||
|     read_only: bool | ||||
|     latest_valid: datetime | ||||
|     has_license: bool | ||||
|  | ||||
|  | ||||
| class LicenseSummarySerializer(PassiveSerializer): | ||||
|     """Serializer for license status""" | ||||
|  | ||||
|     internal_users = IntegerField(required=True) | ||||
|     external_users = IntegerField(required=True) | ||||
|     valid = BooleanField() | ||||
|     show_admin_warning = BooleanField() | ||||
|     show_user_warning = BooleanField() | ||||
|     read_only = BooleanField() | ||||
|     latest_valid = DateTimeField() | ||||
|     has_license = BooleanField() | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class LicenseKey: | ||||
|     """License JWT claims""" | ||||
|  | ||||
|     aud: str | ||||
|     exp: int | ||||
|  | ||||
|     name: str | ||||
|     internal_users: int = 0 | ||||
|     external_users: int = 0 | ||||
|     flags: list[LicenseFlags] = field(default_factory=list) | ||||
|  | ||||
|     @staticmethod | ||||
|     def validate(jwt: str) -> "LicenseKey": | ||||
|         """Validate the license from a given JWT""" | ||||
|         try: | ||||
|             headers = get_unverified_header(jwt) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         x5c: list[str] = headers.get("x5c", []) | ||||
|         if len(x5c) < 1: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         try: | ||||
|             our_cert = load_der_x509_certificate(b64decode(x5c[0])) | ||||
|             intermediate = load_der_x509_certificate(b64decode(x5c[1])) | ||||
|             our_cert.verify_directly_issued_by(intermediate) | ||||
|             intermediate.verify_directly_issued_by(get_licensing_key()) | ||||
|         except (InvalidSignature, TypeError, ValueError, Error): | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         try: | ||||
|             body = from_dict( | ||||
|                 LicenseKey, | ||||
|                 decode( | ||||
|                     jwt, | ||||
|                     our_cert.public_key(), | ||||
|                     algorithms=["ES512"], | ||||
|                     audience=get_license_aud(), | ||||
|                 ), | ||||
|             ) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         return body | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_total() -> "LicenseKey": | ||||
|         """Get a summarized version of all (not expired) licenses""" | ||||
|         active_licenses = License.objects.filter(expiry__gte=now()) | ||||
|         total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) | ||||
|         for lic in active_licenses: | ||||
|             total.internal_users += lic.internal_users | ||||
|             total.external_users += lic.external_users | ||||
|             exp_ts = int(mktime(lic.expiry.timetuple())) | ||||
|             if total.exp == 0: | ||||
|                 total.exp = exp_ts | ||||
|             if exp_ts <= total.exp: | ||||
|                 total.exp = exp_ts | ||||
|             total.flags.extend(lic.status.flags) | ||||
|         return total | ||||
|  | ||||
|     @staticmethod | ||||
|     def base_user_qs() -> QuerySet: | ||||
|         """Base query set for all users""" | ||||
|         return User.objects.all().exclude_anonymous().exclude(is_active=False) | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_default_user_count(): | ||||
|         """Get current default user count""" | ||||
|         return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_external_user_count(): | ||||
|         """Get current external user count""" | ||||
|         # Count since start of the month | ||||
|         last_month = now().replace(day=1) | ||||
|         return ( | ||||
|             LicenseKey.base_user_qs() | ||||
|             .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) | ||||
|             .count() | ||||
|         ) | ||||
|  | ||||
|     def is_valid(self) -> bool: | ||||
|         """Check if the given license body covers all users | ||||
|  | ||||
|         Only checks the current count, no historical data is checked""" | ||||
|         default_users = self.get_default_user_count() | ||||
|         if default_users > self.internal_users: | ||||
|             return False | ||||
|         active_users = self.get_external_user_count() | ||||
|         if active_users > self.external_users: | ||||
|             return False | ||||
|         return True | ||||
|  | ||||
|     def record_usage(self): | ||||
|         """Capture the current validity status and metrics and save them""" | ||||
|         threshold = now() - timedelta(hours=8) | ||||
|         if not LicenseUsage.objects.filter(record_date__gte=threshold).exists(): | ||||
|             LicenseUsage.objects.create( | ||||
|                 user_count=self.get_default_user_count(), | ||||
|                 external_user_count=self.get_external_user_count(), | ||||
|                 within_limits=self.is_valid(), | ||||
|             ) | ||||
|         summary = asdict(self.summary()) | ||||
|         # Also cache the latest summary for the middleware | ||||
|         cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE) | ||||
|         return summary | ||||
|  | ||||
|     @staticmethod | ||||
|     def last_valid_date() -> datetime: | ||||
|         """Get the last date the license was valid""" | ||||
|         usage: LicenseUsage = ( | ||||
|             LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() | ||||
|         ) | ||||
|         if not usage: | ||||
|             return now() | ||||
|         return usage.record_date | ||||
|  | ||||
|     def summary(self) -> LicenseSummary: | ||||
|         """Summary of license status""" | ||||
|         has_license = License.objects.all().count() > 0 | ||||
|         last_valid = LicenseKey.last_valid_date() | ||||
|         show_admin_warning = last_valid < now() - timedelta(weeks=2) | ||||
|         show_user_warning = last_valid < now() - timedelta(weeks=4) | ||||
|         read_only = last_valid < now() - timedelta(weeks=6) | ||||
|         latest_valid = datetime.fromtimestamp(self.exp) | ||||
|         return LicenseSummary( | ||||
|             show_admin_warning=show_admin_warning and has_license, | ||||
|             show_user_warning=show_user_warning and has_license, | ||||
|             read_only=read_only and has_license, | ||||
|             latest_valid=latest_valid, | ||||
|             internal_users=self.internal_users, | ||||
|             external_users=self.external_users, | ||||
|             valid=self.is_valid(), | ||||
|             has_license=has_license, | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def cached_summary() -> LicenseSummary: | ||||
|         """Helper method which looks up the last summary""" | ||||
|         summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|         if not summary: | ||||
|             return LicenseKey.get_total().summary() | ||||
|         return from_dict(LicenseSummary, summary) | ||||
							
								
								
									
										64
									
								
								authentik/enterprise/middleware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								authentik/enterprise/middleware.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,64 @@ | ||||
| """Enterprise middleware""" | ||||
|  | ||||
| from collections.abc import Callable | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse, JsonResponse | ||||
| from django.urls import resolve | ||||
| from structlog.stdlib import BoundLogger, get_logger | ||||
|  | ||||
| from authentik.enterprise.api import LicenseViewSet | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.flows.views.executor import FlowExecutorView | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
|  | ||||
|  | ||||
| class EnterpriseMiddleware: | ||||
|     """Enterprise middleware""" | ||||
|  | ||||
|     get_response: Callable[[HttpRequest], HttpResponse] | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): | ||||
|         self.get_response = get_response | ||||
|         self.logger = get_logger().bind() | ||||
|  | ||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||
|         resolver_match = resolve(request.path_info) | ||||
|         request.resolver_match = resolver_match | ||||
|         if not self.is_request_allowed(request): | ||||
|             self.logger.warning("Refusing request due to expired/invalid license") | ||||
|             return JsonResponse( | ||||
|                 { | ||||
|                     "detail": "Request denied due to expired/invalid license.", | ||||
|                     "code": "denied_license", | ||||
|                 }, | ||||
|                 status=400, | ||||
|             ) | ||||
|         return self.get_response(request) | ||||
|  | ||||
|     def is_request_allowed(self, request: HttpRequest) -> bool: | ||||
|         """Check if a specific request is allowed""" | ||||
|         if self.is_request_always_allowed(request): | ||||
|             return True | ||||
|         cached_status = LicenseKey.cached_summary() | ||||
|         if not cached_status: | ||||
|             return True | ||||
|         if cached_status.read_only: | ||||
|             return False | ||||
|         return True | ||||
|  | ||||
|     def is_request_always_allowed(self, request: HttpRequest): | ||||
|         """Check if a request is always allowed""" | ||||
|         # Always allow "safe" methods | ||||
|         if request.method.lower() in ["get", "head", "options", "trace"]: | ||||
|             return True | ||||
|         # Always allow requests to manage licenses | ||||
|         if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet): | ||||
|             return True | ||||
|         # Flow executor is mounted as an API path but explicitly allowed | ||||
|         if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView): | ||||
|             return True | ||||
|         # Only apply these restrictions to the API | ||||
|         if "authentik_api" not in request.resolver_match.app_names: | ||||
|             return True | ||||
|         return False | ||||
| @ -1,159 +1,20 @@ | ||||
| """Enterprise models""" | ||||
|  | ||||
| from base64 import b64decode | ||||
| from binascii import Error | ||||
| from dataclasses import dataclass, field | ||||
| from datetime import datetime, timedelta | ||||
| from enum import Enum | ||||
| from functools import lru_cache | ||||
| from time import mktime | ||||
| from datetime import timedelta | ||||
| from typing import TYPE_CHECKING | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from cryptography.exceptions import InvalidSignature | ||||
| from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate | ||||
| from dacite import from_dict | ||||
| from django.contrib.postgres.indexes import HashIndex | ||||
| from django.db import models | ||||
| from django.db.models.query import QuerySet | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext as _ | ||||
| from jwt import PyJWTError, decode, get_unverified_header | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.serializers import BaseSerializer | ||||
|  | ||||
| from authentik.core.models import ExpiringModel, User, UserTypes | ||||
| from authentik.core.models import ExpiringModel | ||||
| from authentik.lib.models import SerializerModel | ||||
| from authentik.root.install_id import get_install_id | ||||
|  | ||||
|  | ||||
| @lru_cache() | ||||
| def get_licensing_key() -> Certificate: | ||||
|     """Get Root CA PEM""" | ||||
|     with open("authentik/enterprise/public.pem", "rb") as _key: | ||||
|         return load_pem_x509_certificate(_key.read()) | ||||
|  | ||||
|  | ||||
| def get_license_aud() -> str: | ||||
|     """Get the JWT audience field""" | ||||
|     return f"enterprise.goauthentik.io/license/{get_install_id()}" | ||||
|  | ||||
|  | ||||
| class LicenseFlags(Enum): | ||||
|     """License flags""" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class LicenseKey: | ||||
|     """License JWT claims""" | ||||
|  | ||||
|     aud: str | ||||
|     exp: int | ||||
|  | ||||
|     name: str | ||||
|     internal_users: int = 0 | ||||
|     external_users: int = 0 | ||||
|     flags: list[LicenseFlags] = field(default_factory=list) | ||||
|  | ||||
|     @staticmethod | ||||
|     def validate(jwt: str) -> "LicenseKey": | ||||
|         """Validate the license from a given JWT""" | ||||
|         try: | ||||
|             headers = get_unverified_header(jwt) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         x5c: list[str] = headers.get("x5c", []) | ||||
|         if len(x5c) < 1: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         try: | ||||
|             our_cert = load_der_x509_certificate(b64decode(x5c[0])) | ||||
|             intermediate = load_der_x509_certificate(b64decode(x5c[1])) | ||||
|             our_cert.verify_directly_issued_by(intermediate) | ||||
|             intermediate.verify_directly_issued_by(get_licensing_key()) | ||||
|         except (InvalidSignature, TypeError, ValueError, Error): | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         try: | ||||
|             body = from_dict( | ||||
|                 LicenseKey, | ||||
|                 decode( | ||||
|                     jwt, | ||||
|                     our_cert.public_key(), | ||||
|                     algorithms=["ES512"], | ||||
|                     audience=get_license_aud(), | ||||
|                 ), | ||||
|             ) | ||||
|         except PyJWTError: | ||||
|             raise ValidationError("Unable to verify license") | ||||
|         return body | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_total() -> "LicenseKey": | ||||
|         """Get a summarized version of all (not expired) licenses""" | ||||
|         active_licenses = License.objects.filter(expiry__gte=now()) | ||||
|         total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) | ||||
|         for lic in active_licenses: | ||||
|             total.internal_users += lic.internal_users | ||||
|             total.external_users += lic.external_users | ||||
|             exp_ts = int(mktime(lic.expiry.timetuple())) | ||||
|             if total.exp == 0: | ||||
|                 total.exp = exp_ts | ||||
|             if exp_ts <= total.exp: | ||||
|                 total.exp = exp_ts | ||||
|             total.flags.extend(lic.status.flags) | ||||
|         return total | ||||
|  | ||||
|     @staticmethod | ||||
|     def base_user_qs() -> QuerySet: | ||||
|         """Base query set for all users""" | ||||
|         return User.objects.all().exclude_anonymous().exclude(is_active=False) | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_default_user_count(): | ||||
|         """Get current default user count""" | ||||
|         return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count() | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_external_user_count(): | ||||
|         """Get current external user count""" | ||||
|         # Count since start of the month | ||||
|         last_month = now().replace(day=1) | ||||
|         return ( | ||||
|             LicenseKey.base_user_qs() | ||||
|             .filter(type=UserTypes.EXTERNAL, last_login__gte=last_month) | ||||
|             .count() | ||||
|         ) | ||||
|  | ||||
|     def is_valid(self) -> bool: | ||||
|         """Check if the given license body covers all users | ||||
|  | ||||
|         Only checks the current count, no historical data is checked""" | ||||
|         default_users = self.get_default_user_count() | ||||
|         if default_users > self.internal_users: | ||||
|             return False | ||||
|         active_users = self.get_external_user_count() | ||||
|         if active_users > self.external_users: | ||||
|             return False | ||||
|         return True | ||||
|  | ||||
|     def record_usage(self): | ||||
|         """Capture the current validity status and metrics and save them""" | ||||
|         threshold = now() - timedelta(hours=8) | ||||
|         if LicenseUsage.objects.filter(record_date__gte=threshold).exists(): | ||||
|             return | ||||
|         LicenseUsage.objects.create( | ||||
|             user_count=self.get_default_user_count(), | ||||
|             external_user_count=self.get_external_user_count(), | ||||
|             within_limits=self.is_valid(), | ||||
|         ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def last_valid_date() -> datetime: | ||||
|         """Get the last date the license was valid""" | ||||
|         usage: LicenseUsage = ( | ||||
|             LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first() | ||||
|         ) | ||||
|         if not usage: | ||||
|             return now() | ||||
|         return usage.record_date | ||||
| if TYPE_CHECKING: | ||||
|     from authentik.enterprise.license import LicenseKey | ||||
|  | ||||
|  | ||||
| class License(SerializerModel): | ||||
| @ -174,8 +35,10 @@ class License(SerializerModel): | ||||
|         return LicenseSerializer | ||||
|  | ||||
|     @property | ||||
|     def status(self) -> LicenseKey: | ||||
|     def status(self) -> "LicenseKey": | ||||
|         """Get parsed license status""" | ||||
|         from authentik.enterprise.license import LicenseKey | ||||
|  | ||||
|         return LicenseKey.validate(self.key) | ||||
|  | ||||
|     class Meta: | ||||
|  | ||||
| @ -5,7 +5,7 @@ from typing import Optional | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
|  | ||||
| from authentik.core.models import User, UserTypes | ||||
| from authentik.enterprise.models import LicenseKey | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.policies.types import PolicyRequest, PolicyResult | ||||
| from authentik.policies.views import PolicyAccessView | ||||
|  | ||||
|  | ||||
							
								
								
									
										53
									
								
								authentik/enterprise/providers/rac/api/connection_tokens.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								authentik/enterprise/providers/rac/api/connection_tokens.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,53 @@ | ||||
| """RAC Provider API Views""" | ||||
|  | ||||
| from django_filters.rest_framework.backends import DjangoFilterBackend | ||||
| from rest_framework import mixins | ||||
| from rest_framework.filters import OrderingFilter, SearchFilter | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
|  | ||||
| from authentik.api.authorization import OwnerFilter, OwnerSuperuserPermissions | ||||
| from authentik.core.api.groups import GroupMemberSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.enterprise.api import EnterpriseRequiredMixin | ||||
| from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer | ||||
| from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer | ||||
| from authentik.enterprise.providers.rac.models import ConnectionToken | ||||
|  | ||||
|  | ||||
| class ConnectionTokenSerializer(EnterpriseRequiredMixin, ModelSerializer): | ||||
|     """ConnectionToken Serializer""" | ||||
|  | ||||
|     provider_obj = RACProviderSerializer(source="provider", read_only=True) | ||||
|     endpoint_obj = EndpointSerializer(source="endpoint", read_only=True) | ||||
|     user = GroupMemberSerializer(source="session.user", read_only=True) | ||||
|  | ||||
|     class Meta: | ||||
|         model = ConnectionToken | ||||
|         fields = [ | ||||
|             "pk", | ||||
|             "provider", | ||||
|             "provider_obj", | ||||
|             "endpoint", | ||||
|             "endpoint_obj", | ||||
|             "user", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| class ConnectionTokenViewSet( | ||||
|     mixins.RetrieveModelMixin, | ||||
|     mixins.UpdateModelMixin, | ||||
|     mixins.DestroyModelMixin, | ||||
|     UsedByMixin, | ||||
|     mixins.ListModelMixin, | ||||
|     GenericViewSet, | ||||
| ): | ||||
|     """ConnectionToken Viewset""" | ||||
|  | ||||
|     queryset = ConnectionToken.objects.all().select_related("session", "endpoint") | ||||
|     serializer_class = ConnectionTokenSerializer | ||||
|     filterset_fields = ["endpoint", "session__user", "provider"] | ||||
|     search_fields = ["endpoint__name", "provider__name"] | ||||
|     ordering = ["endpoint__name", "provider__name"] | ||||
|     permission_classes = [OwnerSuperuserPermissions] | ||||
|     filter_backends = [OwnerFilter, DjangoFilterBackend, OrderingFilter, SearchFilter] | ||||
| @ -16,7 +16,12 @@ class RACProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer): | ||||
|  | ||||
|     class Meta: | ||||
|         model = RACProvider | ||||
|         fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"] | ||||
|         fields = ProviderSerializer.Meta.fields + [ | ||||
|             "settings", | ||||
|             "outpost_set", | ||||
|             "connection_expiry", | ||||
|             "delete_token_on_disconnect", | ||||
|         ] | ||||
|         extra_kwargs = ProviderSerializer.Meta.extra_kwargs | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -12,7 +12,3 @@ class AuthentikEnterpriseProviderRAC(EnterpriseConfig): | ||||
|     default = True | ||||
|     mountpoint = "" | ||||
|     ws_mountpoint = "authentik.enterprise.providers.rac.urls" | ||||
|  | ||||
|     def reconcile_global_load_rac_signals(self): | ||||
|         """Load rac signals""" | ||||
|         self.import_module("authentik.enterprise.providers.rac.signals") | ||||
|  | ||||
| @ -43,6 +43,7 @@ class RACClientConsumer(AsyncWebsocketConsumer): | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     async def connect(self): | ||||
|         self.logger = get_logger() | ||||
|         await self.accept("guacamole") | ||||
|         await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name) | ||||
|         await self.channel_layer.group_add( | ||||
| @ -64,9 +65,11 @@ class RACClientConsumer(AsyncWebsocketConsumer): | ||||
|     @database_sync_to_async | ||||
|     def init_outpost_connection(self): | ||||
|         """Initialize guac connection settings""" | ||||
|         self.token = ConnectionToken.filter_not_expired( | ||||
|             token=self.scope["url_route"]["kwargs"]["token"] | ||||
|         ).first() | ||||
|         self.token = ( | ||||
|             ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"]) | ||||
|             .select_related("endpoint", "provider", "session", "session__user") | ||||
|             .first() | ||||
|         ) | ||||
|         if not self.token: | ||||
|             raise DenyConnection() | ||||
|         self.provider = self.token.provider | ||||
| @ -107,6 +110,9 @@ class RACClientConsumer(AsyncWebsocketConsumer): | ||||
|                 OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid}, | ||||
|                 msg, | ||||
|             ) | ||||
|         if self.provider and self.provider.delete_token_on_disconnect: | ||||
|             self.logger.info("Deleting connection token to prevent reconnect", token=self.token) | ||||
|             self.token.delete() | ||||
|  | ||||
|     async def receive(self, text_data=None, bytes_data=None): | ||||
|         """Mirror data received from client to the dest_channel_id | ||||
|  | ||||
| @ -0,0 +1,181 @@ | ||||
| # Generated by Django 5.0.1 on 2024-02-11 19:04 | ||||
|  | ||||
| import uuid | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
| import authentik.core.models | ||||
| import authentik.lib.utils.time | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     replaces = [ | ||||
|         ("authentik_providers_rac", "0001_initial"), | ||||
|         ("authentik_providers_rac", "0002_endpoint_maximum_connections"), | ||||
|         ("authentik_providers_rac", "0003_alter_connectiontoken_options_and_more"), | ||||
|     ] | ||||
|  | ||||
|     initial = True | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0032_group_roles"), | ||||
|         ("authentik_policies", "0011_policybinding_failure_result_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="RACPropertyMapping", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "propertymapping_ptr", | ||||
|                     models.OneToOneField( | ||||
|                         auto_created=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         parent_link=True, | ||||
|                         primary_key=True, | ||||
|                         serialize=False, | ||||
|                         to="authentik_core.propertymapping", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("static_settings", models.JSONField(default=dict)), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "RAC Property Mapping", | ||||
|                 "verbose_name_plural": "RAC Property Mappings", | ||||
|             }, | ||||
|             bases=("authentik_core.propertymapping",), | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="RACProvider", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "provider_ptr", | ||||
|                     models.OneToOneField( | ||||
|                         auto_created=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         parent_link=True, | ||||
|                         primary_key=True, | ||||
|                         serialize=False, | ||||
|                         to="authentik_core.provider", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("settings", models.JSONField(default=dict)), | ||||
|                 ( | ||||
|                     "auth_mode", | ||||
|                     models.TextField( | ||||
|                         choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt" | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "connection_expiry", | ||||
|                     models.TextField( | ||||
|                         default="hours=8", | ||||
|                         help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)", | ||||
|                         validators=[authentik.lib.utils.time.timedelta_string_validator], | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "delete_token_on_disconnect", | ||||
|                     models.BooleanField( | ||||
|                         default=False, | ||||
|                         help_text="When set to true, connection tokens will be deleted upon disconnect.", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "RAC Provider", | ||||
|                 "verbose_name_plural": "RAC Providers", | ||||
|             }, | ||||
|             bases=("authentik_core.provider",), | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="Endpoint", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "policybindingmodel_ptr", | ||||
|                     models.OneToOneField( | ||||
|                         auto_created=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         parent_link=True, | ||||
|                         primary_key=True, | ||||
|                         serialize=False, | ||||
|                         to="authentik_policies.policybindingmodel", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("name", models.TextField()), | ||||
|                 ("host", models.TextField()), | ||||
|                 ( | ||||
|                     "protocol", | ||||
|                     models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]), | ||||
|                 ), | ||||
|                 ("settings", models.JSONField(default=dict)), | ||||
|                 ( | ||||
|                     "auth_mode", | ||||
|                     models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "property_mappings", | ||||
|                     models.ManyToManyField( | ||||
|                         blank=True, default=None, to="authentik_core.propertymapping" | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "provider", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_providers_rac.racprovider", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("maximum_connections", models.IntegerField(default=1)), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "RAC Endpoint", | ||||
|                 "verbose_name_plural": "RAC Endpoints", | ||||
|             }, | ||||
|             bases=("authentik_policies.policybindingmodel", models.Model), | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="ConnectionToken", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "expires", | ||||
|                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||
|                 ), | ||||
|                 ("expiring", models.BooleanField(default=True)), | ||||
|                 ( | ||||
|                     "connection_token_uuid", | ||||
|                     models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), | ||||
|                 ), | ||||
|                 ("token", models.TextField(default=authentik.core.models.default_token_key)), | ||||
|                 ("settings", models.JSONField(default=dict)), | ||||
|                 ( | ||||
|                     "endpoint", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_providers_rac.endpoint", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "provider", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_providers_rac.racprovider", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "session", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_core.authenticatedsession", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|             options={ | ||||
|                 "abstract": False, | ||||
|                 "verbose_name": "RAC Connection token", | ||||
|                 "verbose_name_plural": "RAC Connection tokens", | ||||
|             }, | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,28 @@ | ||||
| # Generated by Django 5.0.1 on 2024-02-11 19:04 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_rac", "0002_endpoint_maximum_connections"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AlterModelOptions( | ||||
|             name="connectiontoken", | ||||
|             options={ | ||||
|                 "verbose_name": "RAC Connection token", | ||||
|                 "verbose_name_plural": "RAC Connection tokens", | ||||
|             }, | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="racprovider", | ||||
|             name="delete_token_on_disconnect", | ||||
|             field=models.BooleanField( | ||||
|                 default=False, | ||||
|                 help_text="When set to true, connection tokens will be deleted upon disconnect.", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,17 +1,18 @@ | ||||
| """RAC Models""" | ||||
|  | ||||
| from typing import Optional | ||||
| from typing import Any, Optional | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from deepmerge import always_merger | ||||
| from django.db import models | ||||
| from django.db.models import QuerySet | ||||
| from django.http import HttpRequest | ||||
| from django.utils.translation import gettext as _ | ||||
| from rest_framework.serializers import Serializer | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.exceptions import PropertyMappingExpressionException | ||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key | ||||
| from authentik.core.models import ExpiringModel, PropertyMapping, Provider, User, default_token_key | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.models import SerializerModel | ||||
| from authentik.lib.utils.time import timedelta_string_validator | ||||
| @ -51,6 +52,10 @@ class RACProvider(Provider): | ||||
|             "(Format: hours=-1;minutes=-2;seconds=-3)" | ||||
|         ), | ||||
|     ) | ||||
|     delete_token_on_disconnect = models.BooleanField( | ||||
|         default=False, | ||||
|         help_text=_("When set to true, connection tokens will be deleted upon disconnect."), | ||||
|     ) | ||||
|  | ||||
|     @property | ||||
|     def launch_url(self) -> Optional[str]: | ||||
| @ -107,6 +112,12 @@ class RACPropertyMapping(PropertyMapping): | ||||
|  | ||||
|     static_settings = models.JSONField(default=dict) | ||||
|  | ||||
|     def evaluate(self, user: Optional[User], request: Optional[HttpRequest], **kwargs) -> Any: | ||||
|         """Evaluate `self.expression` using `**kwargs` as Context.""" | ||||
|         if len(self.static_settings) > 0: | ||||
|             return self.static_settings | ||||
|         return super().evaluate(user, request, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def component(self) -> str: | ||||
|         return "ak-property-mapping-rac-form" | ||||
| @ -155,9 +166,6 @@ class ConnectionToken(ExpiringModel): | ||||
|         def mapping_evaluator(mappings: QuerySet): | ||||
|             for mapping in mappings: | ||||
|                 mapping: RACPropertyMapping | ||||
|                 if len(mapping.static_settings) > 0: | ||||
|                     always_merger.merge(settings, mapping.static_settings) | ||||
|                     continue | ||||
|                 try: | ||||
|                     mapping_settings = mapping.evaluate( | ||||
|                         self.session.user, None, endpoint=self.endpoint, provider=self.provider | ||||
| @ -191,3 +199,13 @@ class ConnectionToken(ExpiringModel): | ||||
|                 continue | ||||
|             settings[key] = str(value) | ||||
|         return settings | ||||
|  | ||||
|     def __str__(self): | ||||
|         return ( | ||||
|             f"RAC Connection token {self.session.user} to " | ||||
|             f"{self.endpoint.provider.name}/{self.endpoint.name}" | ||||
|         ) | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("RAC Connection token") | ||||
|         verbose_name_plural = _("RAC Connection tokens") | ||||
|  | ||||
| @ -45,8 +45,8 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, ** | ||||
|  | ||||
|  | ||||
| @receiver(post_save, sender=Endpoint) | ||||
| def post_save_application(sender: type[Model], instance, created: bool, **_): | ||||
|     """Clear user's application cache upon application creation""" | ||||
| def post_save_endpoint(sender: type[Model], instance, created: bool, **_): | ||||
|     """Clear user's endpoint cache upon endpoint creation""" | ||||
|     if not created:  # pragma: no cover | ||||
|         return | ||||
|  | ||||
|  | ||||
| @ -70,6 +70,7 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "authorization_flow": None, | ||||
|                             "property_mappings": [], | ||||
|                             "connection_expiry": "hours=8", | ||||
|                             "delete_token_on_disconnect": False, | ||||
|                             "component": "ak-provider-rac-form", | ||||
|                             "assigned_application_slug": self.app.slug, | ||||
|                             "assigned_application_name": self.app.name, | ||||
| @ -124,6 +125,7 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "assigned_application_slug": self.app.slug, | ||||
|                             "assigned_application_name": self.app.name, | ||||
|                             "connection_expiry": "hours=8", | ||||
|                             "delete_token_on_disconnect": False, | ||||
|                             "verbose_name": "RAC Provider", | ||||
|                             "verbose_name_plural": "RAC Providers", | ||||
|                             "meta_model_name": "authentik_providers_rac.racprovider", | ||||
| @ -152,6 +154,7 @@ class TestEndpointsAPI(APITestCase): | ||||
|                             "assigned_application_slug": self.app.slug, | ||||
|                             "assigned_application_name": self.app.name, | ||||
|                             "connection_expiry": "hours=8", | ||||
|                             "delete_token_on_disconnect": False, | ||||
|                             "verbose_name": "RAC Provider", | ||||
|                             "verbose_name_plural": "RAC Providers", | ||||
|                             "meta_model_name": "authentik_providers_rac.racprovider", | ||||
|  | ||||
| @ -11,7 +11,8 @@ from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.enterprise.models import License, LicenseKey | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.denied import AccessDeniedResponse | ||||
| @ -39,7 +40,7 @@ class TestRACViews(APITestCase): | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.models.LicenseKey.validate", | ||||
|         "authentik.enterprise.license.LicenseKey.validate", | ||||
|         MagicMock( | ||||
|             return_value=LicenseKey( | ||||
|                 aud="", | ||||
| @ -70,7 +71,7 @@ class TestRACViews(APITestCase): | ||||
|         self.assertEqual(final_response.status_code, 200) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.models.LicenseKey.validate", | ||||
|         "authentik.enterprise.license.LicenseKey.validate", | ||||
|         MagicMock( | ||||
|             return_value=LicenseKey( | ||||
|                 aud="", | ||||
| @ -99,7 +100,7 @@ class TestRACViews(APITestCase): | ||||
|         self.assertIsInstance(response, AccessDeniedResponse) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.models.LicenseKey.validate", | ||||
|         "authentik.enterprise.license.LicenseKey.validate", | ||||
|         MagicMock( | ||||
|             return_value=LicenseKey( | ||||
|                 aud="", | ||||
|  | ||||
| @ -6,6 +6,7 @@ from django.urls import path | ||||
| from django.views.decorators.csrf import ensure_csrf_cookie | ||||
|  | ||||
| from authentik.core.channels import TokenOutpostMiddleware | ||||
| from authentik.enterprise.providers.rac.api.connection_tokens import ConnectionTokenViewSet | ||||
| from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet | ||||
| from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet | ||||
| from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet | ||||
| @ -45,4 +46,5 @@ api_urlpatterns = [ | ||||
|     ("providers/rac", RACProviderViewSet), | ||||
|     ("propertymappings/rac", RACPropertyMappingViewSet), | ||||
|     ("rac/endpoints", EndpointViewSet), | ||||
|     ("rac/connection_tokens", ConnectionTokenViewSet), | ||||
| ] | ||||
|  | ||||
| @ -104,7 +104,8 @@ class RACFinalStage(RedirectStage): | ||||
|         # Check if we're already at the maximum connection limit | ||||
|         all_tokens = ConnectionToken.filter_not_expired( | ||||
|             endpoint=self.endpoint, | ||||
|         ).exclude(endpoint__maximum_connections__lte=-1) | ||||
|         ) | ||||
|         if self.endpoint.maximum_connections > -1: | ||||
|             if all_tokens.count() >= self.endpoint.maximum_connections: | ||||
|                 msg = [_("Maximum connection limit reached.")] | ||||
|                 # Check if any other tokens exist for the current user, and inform them | ||||
|  | ||||
| @ -5,9 +5,9 @@ from celery.schedules import crontab | ||||
| from authentik.lib.utils.time import fqdn_rand | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     "enterprise_calculate_license": { | ||||
|         "task": "authentik.enterprise.tasks.calculate_license", | ||||
|         "schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/2"), | ||||
|     "enterprise_update_usage": { | ||||
|         "task": "authentik.enterprise.tasks.enterprise_update_usage", | ||||
|         "schedule": crontab(minute=fqdn_rand("enterprise_update_usage"), hour="*/2"), | ||||
|         "options": {"queue": "authentik_scheduled"}, | ||||
|     } | ||||
| } | ||||
| @ -16,3 +16,5 @@ TENANT_APPS = [ | ||||
|     "authentik.enterprise.audit", | ||||
|     "authentik.enterprise.providers.rac", | ||||
| ] | ||||
|  | ||||
| MIDDLEWARE = ["authentik.enterprise.middleware.EnterpriseMiddleware"] | ||||
|  | ||||
| @ -2,11 +2,14 @@ | ||||
|  | ||||
| from datetime import datetime | ||||
|  | ||||
| from django.db.models.signals import pre_save | ||||
| from django.core.cache import cache | ||||
| from django.db.models.signals import post_save, pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.utils.timezone import get_current_timezone | ||||
|  | ||||
| from authentik.enterprise.license import CACHE_KEY_ENTERPRISE_LICENSE | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.enterprise.tasks import enterprise_update_usage | ||||
|  | ||||
|  | ||||
| @receiver(pre_save, sender=License) | ||||
| @ -17,3 +20,10 @@ def pre_save_license(sender: type[License], instance: License, **_): | ||||
|     instance.internal_users = status.internal_users | ||||
|     instance.external_users = status.external_users | ||||
|     instance.expiry = datetime.fromtimestamp(status.exp, tz=get_current_timezone()) | ||||
|  | ||||
|  | ||||
| @receiver(post_save, sender=License) | ||||
| def post_save_license(sender: type[License], instance: License, **_): | ||||
|     """Trigger license usage calculation when license is saved""" | ||||
|     cache.delete(CACHE_KEY_ENTERPRISE_LICENSE) | ||||
|     enterprise_update_usage.delay() | ||||
|  | ||||
| @ -1,10 +1,14 @@ | ||||
| """Enterprise tasks""" | ||||
|  | ||||
| from authentik.enterprise.models import LicenseKey | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
|  | ||||
| @CELERY_APP.task() | ||||
| def calculate_license(): | ||||
|     """Calculate licensing status""" | ||||
| @CELERY_APP.task(bind=True, base=SystemTask) | ||||
| @prefill_task | ||||
| def enterprise_update_usage(self: SystemTask): | ||||
|     """Update enterprise license status""" | ||||
|     LicenseKey.get_total().record_usage() | ||||
|     self.set_status(TaskStatus.SUCCESSFUL) | ||||
|  | ||||
| @ -8,7 +8,8 @@ from django.test import TestCase | ||||
| from django.utils.timezone import now | ||||
| from rest_framework.exceptions import ValidationError | ||||
|  | ||||
| from authentik.enterprise.models import License, LicenseKey | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.enterprise.models import License | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
| _exp = int(mktime((now() + timedelta(days=3000)).timetuple())) | ||||
| @ -18,7 +19,7 @@ class TestEnterpriseLicense(TestCase): | ||||
|     """Enterprise license tests""" | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.models.LicenseKey.validate", | ||||
|         "authentik.enterprise.license.LicenseKey.validate", | ||||
|         MagicMock( | ||||
|             return_value=LicenseKey( | ||||
|                 aud="", | ||||
| @ -41,7 +42,7 @@ class TestEnterpriseLicense(TestCase): | ||||
|             License.objects.create(key=generate_id()) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.enterprise.models.LicenseKey.validate", | ||||
|         "authentik.enterprise.license.LicenseKey.validate", | ||||
|         MagicMock( | ||||
|             return_value=LicenseKey( | ||||
|                 aud="", | ||||
|  | ||||
| @ -12,7 +12,6 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.events.models import ( | ||||
| @ -24,6 +23,7 @@ from authentik.events.models import ( | ||||
|     TransportMode, | ||||
| ) | ||||
| from authentik.events.utils import get_user | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class NotificationTransportSerializer(ModelSerializer): | ||||
|  | ||||
| @ -1,6 +1,5 @@ | ||||
| """Tasks API""" | ||||
|  | ||||
| from datetime import datetime, timezone | ||||
| from importlib import import_module | ||||
|  | ||||
| from django.contrib import messages | ||||
| @ -8,15 +7,22 @@ from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema | ||||
| from rest_framework.decorators import action | ||||
| from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField | ||||
| from rest_framework.fields import ( | ||||
|     CharField, | ||||
|     ChoiceField, | ||||
|     DateTimeField, | ||||
|     FloatField, | ||||
|     ListField, | ||||
|     SerializerMethodField, | ||||
| ) | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import ReadOnlyModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.events.models import SystemTask, TaskStatus | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -28,9 +34,9 @@ class SystemTaskSerializer(ModelSerializer): | ||||
|     full_name = SerializerMethodField() | ||||
|     uid = CharField(required=False) | ||||
|     description = CharField() | ||||
|     start_timestamp = SerializerMethodField() | ||||
|     finish_timestamp = SerializerMethodField() | ||||
|     duration = SerializerMethodField() | ||||
|     start_timestamp = DateTimeField(read_only=True) | ||||
|     finish_timestamp = DateTimeField(read_only=True) | ||||
|     duration = FloatField(read_only=True) | ||||
|  | ||||
|     status = ChoiceField(choices=[(x.value, x.name) for x in TaskStatus]) | ||||
|     messages = ListField(child=CharField()) | ||||
| @ -41,18 +47,6 @@ class SystemTaskSerializer(ModelSerializer): | ||||
|             return f"{instance.name}:{instance.uid}" | ||||
|         return instance.name | ||||
|  | ||||
|     def get_start_timestamp(self, instance: SystemTask) -> datetime: | ||||
|         """Timestamp when the task started""" | ||||
|         return datetime.fromtimestamp(instance.start_timestamp, tz=timezone.utc) | ||||
|  | ||||
|     def get_finish_timestamp(self, instance: SystemTask) -> datetime: | ||||
|         """Timestamp when the task finished""" | ||||
|         return datetime.fromtimestamp(instance.finish_timestamp, tz=timezone.utc) | ||||
|  | ||||
|     def get_duration(self, instance: SystemTask) -> float: | ||||
|         """Get the duration a task took to run""" | ||||
|         return max(instance.finish_timestamp - instance.start_timestamp, 0) | ||||
|  | ||||
|     class Meta: | ||||
|         model = SystemTask | ||||
|         fields = [ | ||||
| @ -87,7 +81,7 @@ class SystemTaskViewSet(ReadOnlyModelViewSet): | ||||
|             500: OpenApiResponse(description="Failed to retry task"), | ||||
|         }, | ||||
|     ) | ||||
|     @action(detail=True, methods=["post"]) | ||||
|     @action(detail=True, methods=["POST"], permission_classes=[]) | ||||
|     def run(self, request: Request, pk=None) -> Response: | ||||
|         """Run task""" | ||||
|         task: SystemTask = self.get_object() | ||||
|  | ||||
| @ -1,9 +1,12 @@ | ||||
| """authentik events app""" | ||||
|  | ||||
| from celery.schedules import crontab | ||||
| from prometheus_client import Gauge, Histogram | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
| from authentik.lib.config import CONFIG, ENV_PREFIX | ||||
| from authentik.lib.utils.reflection import path_to_class | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| # TODO: Deprecated metric - remove in 2024.2 or later | ||||
| GAUGE_TASKS = Gauge( | ||||
| @ -15,7 +18,7 @@ GAUGE_TASKS = Gauge( | ||||
| SYSTEM_TASK_TIME = Histogram( | ||||
|     "authentik_system_tasks_time_seconds", | ||||
|     "Runtime of system tasks", | ||||
|     ["tenant"], | ||||
|     ["tenant", "task_name", "task_uid"], | ||||
| ) | ||||
| SYSTEM_TASK_STATUS = Gauge( | ||||
|     "authentik_system_tasks_status", | ||||
| @ -32,10 +35,6 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Events" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_events_signals(self): | ||||
|         """Load events signals""" | ||||
|         self.import_module("authentik.events.signals") | ||||
|  | ||||
|     def reconcile_global_check_deprecations(self): | ||||
|         """Check for config deprecations""" | ||||
|         from authentik.events.models import Event, EventAction | ||||
| @ -57,7 +56,7 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|                 message=msg, | ||||
|             ).save() | ||||
|  | ||||
|     def reconcile_prefill_tasks(self): | ||||
|     def reconcile_tenant_prefill_tasks(self): | ||||
|         """Prefill tasks""" | ||||
|         from authentik.events.models import SystemTask | ||||
|         from authentik.events.system_tasks import _prefill_tasks | ||||
| @ -67,3 +66,28 @@ class AuthentikEventsConfig(ManagedAppConfig): | ||||
|                 continue | ||||
|             task.save() | ||||
|             self.logger.debug("prefilled task", task_name=task.name) | ||||
|  | ||||
|     def reconcile_tenant_run_scheduled_tasks(self): | ||||
|         """Run schedule tasks which are behind schedule (only applies | ||||
|         to tasks of which we keep metrics)""" | ||||
|         from authentik.events.models import TaskStatus | ||||
|         from authentik.events.system_tasks import SystemTask as CelerySystemTask | ||||
|  | ||||
|         for task in CELERY_APP.conf["beat_schedule"].values(): | ||||
|             schedule = task["schedule"] | ||||
|             if not isinstance(schedule, crontab): | ||||
|                 continue | ||||
|             task_class: CelerySystemTask = path_to_class(task["task"]) | ||||
|             if not isinstance(task_class, CelerySystemTask): | ||||
|                 continue | ||||
|             db_task = task_class.db() | ||||
|             if not db_task: | ||||
|                 continue | ||||
|             due, _ = schedule.is_due(db_task.finish_timestamp) | ||||
|             if due or db_task.status == TaskStatus.UNKNOWN: | ||||
|                 self.logger.debug("Running past-due scheduled task", task=task["task"]) | ||||
|                 task_class.apply_async( | ||||
|                     args=task.get("args", None), | ||||
|                     kwargs=task.get("kwargs", None), | ||||
|                     **task.get("options", {}), | ||||
|                 ) | ||||
|  | ||||
| @ -82,26 +82,29 @@ class AuditMiddleware: | ||||
|  | ||||
|         self.anonymous_user = get_anonymous_user() | ||||
|  | ||||
|     def get_user(self, request: HttpRequest) -> User: | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             return self.anonymous_user | ||||
|         return user | ||||
|  | ||||
|     def connect(self, request: HttpRequest): | ||||
|         """Connect signal for automatic logging""" | ||||
|         self._ensure_fallback_user() | ||||
|         user = getattr(request, "user", self.anonymous_user) | ||||
|         if not user.is_authenticated: | ||||
|             user = self.anonymous_user | ||||
|         if not hasattr(request, "request_id"): | ||||
|             return | ||||
|         post_save.connect( | ||||
|             partial(self.post_save_handler, user=user, request=request), | ||||
|             partial(self.post_save_handler, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         pre_delete.connect( | ||||
|             partial(self.pre_delete_handler, user=user, request=request), | ||||
|             partial(self.pre_delete_handler, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
|         m2m_changed.connect( | ||||
|             partial(self.m2m_changed_handler, user=user, request=request), | ||||
|             partial(self.m2m_changed_handler, request=request), | ||||
|             dispatch_uid=request.request_id, | ||||
|             weak=False, | ||||
|         ) | ||||
| @ -147,7 +150,6 @@ class AuditMiddleware: | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def post_save_handler( | ||||
|         self, | ||||
|         user: User, | ||||
|         request: HttpRequest, | ||||
|         sender, | ||||
|         instance: Model, | ||||
| @ -158,16 +160,18 @@ class AuditMiddleware: | ||||
|         """Signal handler for all object's post_save""" | ||||
|         if not should_log_model(instance): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED | ||||
|         thread = EventNewThread(action, request, user=user, model=model_to_dict(instance)) | ||||
|         thread.kwargs.update(thread_kwargs or {}) | ||||
|         thread.run() | ||||
|  | ||||
|     def pre_delete_handler(self, user: User, request: HttpRequest, sender, instance: Model, **_): | ||||
|     def pre_delete_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||
|         """Signal handler for all object's pre_delete""" | ||||
|         if not should_log_model(instance):  # pragma: no cover | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_DELETED, | ||||
| @ -176,14 +180,13 @@ class AuditMiddleware: | ||||
|             model=model_to_dict(instance), | ||||
|         ).run() | ||||
|  | ||||
|     def m2m_changed_handler( | ||||
|         self, user: User, request: HttpRequest, sender, instance: Model, action: str, **_ | ||||
|     ): | ||||
|     def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_): | ||||
|         """Signal handler for all object's m2m_changed""" | ||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||
|             return | ||||
|         if not should_log_m2m(instance): | ||||
|             return | ||||
|         user = self.get_user(request) | ||||
|  | ||||
|         EventNewThread( | ||||
|             EventAction.MODEL_UPDATED, | ||||
|  | ||||
| @ -0,0 +1,68 @@ | ||||
| # Generated by Django 5.0.1 on 2024-02-07 15:42 | ||||
|  | ||||
| import uuid | ||||
|  | ||||
| import django.utils.timezone | ||||
| from django.db import migrations, models | ||||
|  | ||||
| import authentik.core.models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     replaces = [ | ||||
|         ("authentik_events", "0004_systemtask"), | ||||
|         ("authentik_events", "0005_remove_systemtask_finish_timestamp_and_more"), | ||||
|     ] | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0003_rename_tenant_event_brand"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="SystemTask", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "expires", | ||||
|                     models.DateTimeField(default=authentik.core.models.default_token_duration), | ||||
|                 ), | ||||
|                 ("expiring", models.BooleanField(default=True)), | ||||
|                 ( | ||||
|                     "uuid", | ||||
|                     models.UUIDField( | ||||
|                         default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("name", models.TextField()), | ||||
|                 ("uid", models.TextField(null=True)), | ||||
|                 ( | ||||
|                     "status", | ||||
|                     models.TextField( | ||||
|                         choices=[ | ||||
|                             ("unknown", "Unknown"), | ||||
|                             ("successful", "Successful"), | ||||
|                             ("warning", "Warning"), | ||||
|                             ("error", "Error"), | ||||
|                         ] | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("description", models.TextField(null=True)), | ||||
|                 ("messages", models.JSONField()), | ||||
|                 ("task_call_module", models.TextField()), | ||||
|                 ("task_call_func", models.TextField()), | ||||
|                 ("task_call_args", models.JSONField(default=list)), | ||||
|                 ("task_call_kwargs", models.JSONField(default=dict)), | ||||
|                 ("duration", models.FloatField(default=0)), | ||||
|                 ("finish_timestamp", models.DateTimeField(default=django.utils.timezone.now)), | ||||
|                 ("start_timestamp", models.DateTimeField(default=django.utils.timezone.now)), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "System Task", | ||||
|                 "verbose_name_plural": "System Tasks", | ||||
|                 "permissions": [("run_task", "Run task")], | ||||
|                 "default_permissions": ["view"], | ||||
|                 "unique_together": {("name", "uid")}, | ||||
|             }, | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,37 @@ | ||||
| # Generated by Django 5.0.1 on 2024-02-06 18:02 | ||||
|  | ||||
| import django.utils.timezone | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0004_systemtask"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RemoveField( | ||||
|             model_name="systemtask", | ||||
|             name="finish_timestamp", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="systemtask", | ||||
|             name="start_timestamp", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="systemtask", | ||||
|             name="duration", | ||||
|             field=models.FloatField(default=0), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="systemtask", | ||||
|             name="finish_timestamp", | ||||
|             field=models.DateTimeField(default=django.utils.timezone.now), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="systemtask", | ||||
|             name="start_timestamp", | ||||
|             field=models.DateTimeField(default=django.utils.timezone.now), | ||||
|         ), | ||||
|     ] | ||||
| @ -451,6 +451,13 @@ class NotificationTransport(SerializerModel): | ||||
|  | ||||
|     def send_email(self, notification: "Notification") -> list[str]: | ||||
|         """Send notification via global email configuration""" | ||||
|         if notification.user.email.strip() == "": | ||||
|             LOGGER.info( | ||||
|                 "Discarding notification as user has no email address", | ||||
|                 user=notification.user, | ||||
|                 notification=notification, | ||||
|             ) | ||||
|             return None | ||||
|         subject_prefix = "authentik Notification: " | ||||
|         context = { | ||||
|             "key_value": { | ||||
| @ -480,7 +487,7 @@ class NotificationTransport(SerializerModel): | ||||
|             } | ||||
|         mail = TemplateEmailMessage( | ||||
|             subject=subject_prefix + context["title"], | ||||
|             to=[f"{notification.user.name} <{notification.user.email}>"], | ||||
|             to=[(notification.user.name, notification.user.email)], | ||||
|             language=notification.user.locale(), | ||||
|             template_name="email/event_notification.html", | ||||
|             template_context=context, | ||||
| @ -620,8 +627,9 @@ class SystemTask(SerializerModel, ExpiringModel): | ||||
|     name = models.TextField() | ||||
|     uid = models.TextField(null=True) | ||||
|  | ||||
|     start_timestamp = models.FloatField() | ||||
|     finish_timestamp = models.FloatField() | ||||
|     start_timestamp = models.DateTimeField(default=now) | ||||
|     finish_timestamp = models.DateTimeField(default=now) | ||||
|     duration = models.FloatField(default=0) | ||||
|  | ||||
|     status = models.TextField(choices=TaskStatus.choices) | ||||
|  | ||||
| @ -641,17 +649,18 @@ class SystemTask(SerializerModel, ExpiringModel): | ||||
|  | ||||
|     def update_metrics(self): | ||||
|         """Update prometheus metrics""" | ||||
|         duration = max(self.finish_timestamp - self.start_timestamp, 0) | ||||
|         # TODO: Deprecated metric - remove in 2024.2 or later | ||||
|         GAUGE_TASKS.labels( | ||||
|             tenant=connection.schema_name, | ||||
|             task_name=self.name, | ||||
|             task_uid=self.uid or "", | ||||
|             status=self.status.lower(), | ||||
|         ).set(duration) | ||||
|         ).set(self.duration) | ||||
|         SYSTEM_TASK_TIME.labels( | ||||
|             tenant=connection.schema_name, | ||||
|         ).observe(duration) | ||||
|             task_name=self.name, | ||||
|             task_uid=self.uid or "", | ||||
|         ).observe(self.duration) | ||||
|         SYSTEM_TASK_STATUS.labels( | ||||
|             tenant=connection.schema_name, | ||||
|             task_name=self.name, | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| """Monitored tasks""" | ||||
|  | ||||
| from datetime import timedelta | ||||
| from timeit import default_timer | ||||
| from datetime import datetime, timedelta | ||||
| from time import perf_counter | ||||
| from typing import Any, Optional | ||||
|  | ||||
| from django.utils.timezone import now | ||||
| @ -24,14 +24,17 @@ class SystemTask(TenantTask): | ||||
|     # For tasks that should only be listed if they failed, set this to False | ||||
|     save_on_success: bool | ||||
|  | ||||
|     _status: Optional[TaskStatus] | ||||
|     _status: TaskStatus | ||||
|     _messages: list[str] | ||||
|  | ||||
|     _uid: Optional[str] | ||||
|     _start: Optional[float] = None | ||||
|     # Precise start time from perf_counter | ||||
|     _start_precise: Optional[float] = None | ||||
|     _start: Optional[datetime] = None | ||||
|  | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self._status = TaskStatus.SUCCESSFUL | ||||
|         self.save_on_success = True | ||||
|         self._uid = None | ||||
|         self._status = None | ||||
| @ -53,9 +56,17 @@ class SystemTask(TenantTask): | ||||
|         self._messages = [exception_to_string(exception)] | ||||
|  | ||||
|     def before_start(self, task_id, args, kwargs): | ||||
|         self._start = default_timer() | ||||
|         self._start_precise = perf_counter() | ||||
|         self._start = now() | ||||
|         return super().before_start(task_id, args, kwargs) | ||||
|  | ||||
|     def db(self) -> Optional[DBSystemTask]: | ||||
|         """Get DB object for latest task""" | ||||
|         return DBSystemTask.objects.filter( | ||||
|             name=self.__name__, | ||||
|             uid=self._uid, | ||||
|         ).first() | ||||
|  | ||||
|     # pylint: disable=too-many-arguments | ||||
|     def after_return(self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo): | ||||
|         super().after_return(status, retval, task_id, args, kwargs, einfo=einfo) | ||||
| @ -72,12 +83,13 @@ class SystemTask(TenantTask): | ||||
|             uid=self._uid, | ||||
|             defaults={ | ||||
|                 "description": self.__doc__, | ||||
|                 "start_timestamp": self._start or default_timer(), | ||||
|                 "finish_timestamp": default_timer(), | ||||
|                 "start_timestamp": self._start or now(), | ||||
|                 "finish_timestamp": now(), | ||||
|                 "duration": max(perf_counter() - self._start_precise, 0), | ||||
|                 "task_call_module": self.__module__, | ||||
|                 "task_call_func": self.__name__, | ||||
|                 "task_call_args": args, | ||||
|                 "task_call_kwargs": kwargs, | ||||
|                 "task_call_args": sanitize_item(args), | ||||
|                 "task_call_kwargs": sanitize_item(kwargs), | ||||
|                 "status": self._status, | ||||
|                 "messages": sanitize_item(self._messages), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||
| @ -96,12 +108,13 @@ class SystemTask(TenantTask): | ||||
|             uid=self._uid, | ||||
|             defaults={ | ||||
|                 "description": self.__doc__, | ||||
|                 "start_timestamp": self._start or default_timer(), | ||||
|                 "finish_timestamp": default_timer(), | ||||
|                 "start_timestamp": self._start or now(), | ||||
|                 "finish_timestamp": now(), | ||||
|                 "duration": max(perf_counter() - self._start_precise, 0), | ||||
|                 "task_call_module": self.__module__, | ||||
|                 "task_call_func": self.__name__, | ||||
|                 "task_call_args": args, | ||||
|                 "task_call_kwargs": kwargs, | ||||
|                 "task_call_args": sanitize_item(args), | ||||
|                 "task_call_kwargs": sanitize_item(kwargs), | ||||
|                 "status": self._status, | ||||
|                 "messages": sanitize_item(self._messages), | ||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), | ||||
| @ -123,11 +136,14 @@ def prefill_task(func): | ||||
|         DBSystemTask( | ||||
|             name=func.__name__, | ||||
|             description=func.__doc__, | ||||
|             start_timestamp=now(), | ||||
|             finish_timestamp=now(), | ||||
|             status=TaskStatus.UNKNOWN, | ||||
|             messages=sanitize_item([_("Task has not been run yet.")]), | ||||
|             task_call_module=func.__module__, | ||||
|             task_call_func=func.__name__, | ||||
|             expiring=False, | ||||
|             duration=0, | ||||
|         ) | ||||
|     ) | ||||
|     return func | ||||
|  | ||||
| @ -3,9 +3,10 @@ | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.models import Application, Token, TokenIntents | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestEventsMiddleware(APITestCase): | ||||
| @ -47,3 +48,30 @@ class TestEventsMiddleware(APITestCase): | ||||
|                 context__model__name="test-delete", | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_create_with_api(self): | ||||
|         """Test model creation event (with API token auth)""" | ||||
|         self.client.logout() | ||||
|         token = Token.objects.create(user=self.user, intent=TokenIntents.INTENT_API, expiring=False) | ||||
|         uid = generate_id() | ||||
|         self.client.post( | ||||
|             reverse("authentik_api:application-list"), | ||||
|             data={"name": uid, "slug": uid}, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {token.key}", | ||||
|         ) | ||||
|         self.assertTrue(Application.objects.filter(name=uid).exists()) | ||||
|         event = Event.objects.filter( | ||||
|             action=EventAction.MODEL_CREATED, | ||||
|             context__model__model_name="application", | ||||
|             context__model__app="authentik_core", | ||||
|             context__model__name=uid, | ||||
|         ).first() | ||||
|         self.assertIsNotNone(event) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "pk": self.user.pk, | ||||
|                 "email": self.user.email, | ||||
|                 "username": self.user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -15,7 +15,6 @@ from rest_framework.serializers import ModelSerializer, SerializerMethodField | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.blueprints.v1.exporter import FlowExporter | ||||
| from authentik.blueprints.v1.importer import SERIALIZER_CONTEXT_BLUEPRINT, Importer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| @ -33,6 +32,7 @@ from authentik.lib.utils.file import ( | ||||
|     set_file_url, | ||||
| ) | ||||
| from authentik.lib.views import bad_request_message | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @ -31,10 +31,6 @@ class AuthentikFlowsConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Flows" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_flows_signals(self): | ||||
|         """Load flows signals""" | ||||
|         self.import_module("authentik.flows.signals") | ||||
|  | ||||
|     def reconcile_global_load_stages(self): | ||||
|         """Ensure all stages are loaded""" | ||||
|         from authentik.flows.models import Stage | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| """flow views tests""" | ||||
|  | ||||
| from unittest.mock import MagicMock, PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.test.client import RequestFactory | ||||
| @ -18,7 +19,12 @@ from authentik.flows.models import ( | ||||
| from authentik.flows.planner import FlowPlan, FlowPlanner | ||||
| from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView | ||||
| from authentik.flows.tests import FlowTestCase | ||||
| from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView | ||||
| from authentik.flows.views.executor import ( | ||||
|     NEXT_ARG_NAME, | ||||
|     QS_QUERY, | ||||
|     SESSION_KEY_PLAN, | ||||
|     FlowExecutorView, | ||||
| ) | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
| from authentik.policies.models import PolicyBinding | ||||
| @ -121,16 +127,73 @@ class TestFlowExecutor(FlowTestCase): | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_flow_redirect(self): | ||||
|         """Tests that an invalid flow still redirects""" | ||||
|         """Test invalid flow with valid redirect destination""" | ||||
|         flow = create_test_flow( | ||||
|             FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|  | ||||
|         dest = "/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|         response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         self.assertEqual(response.url, reverse("authentik_core:root-redirect")) | ||||
|         self.assertEqual(response.url, "/unique-string") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_flow_invalid_redirect(self): | ||||
|         """Test invalid flow redirect with an invalid URL""" | ||||
|         flow = create_test_flow( | ||||
|             FlowDesignation.AUTHENTICATION, | ||||
|         ) | ||||
|  | ||||
|         dest = "http://something.example.com/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageResponse( | ||||
|             response, | ||||
|             flow, | ||||
|             component="ak-stage-access-denied", | ||||
|             error_message="Invalid next URL", | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_valid_flow_redirect(self): | ||||
|         """Test valid flow with valid redirect destination""" | ||||
|         flow = create_test_flow() | ||||
|  | ||||
|         dest = "/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         self.assertEqual(response.url, "/unique-string") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_valid_flow_invalid_redirect(self): | ||||
|         """Test valid flow redirect with an invalid URL""" | ||||
|         flow = create_test_flow() | ||||
|  | ||||
|         dest = "http://something.example.com/unique-string" | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         response = self.client.get(url + f"?{QS_QUERY}={urlencode({NEXT_ARG_NAME: dest})}") | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageResponse( | ||||
|             response, | ||||
|             flow, | ||||
|             component="ak-stage-access-denied", | ||||
|             error_message="Invalid next URL", | ||||
|         ) | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|  | ||||
| @ -12,6 +12,7 @@ from django.shortcuts import get_object_or_404, redirect | ||||
| from django.template.response import TemplateResponse | ||||
| from django.urls import reverse | ||||
| from django.utils.decorators import method_decorator | ||||
| from django.utils.translation import gettext as _ | ||||
| from django.views.decorators.clickjacking import xframe_options_sameorigin | ||||
| from django.views.generic import View | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| @ -178,6 +179,8 @@ class FlowExecutorView(APIView): | ||||
|                     self.cancel() | ||||
|                 self._logger.debug("f(exec): Continuing existing plan") | ||||
|  | ||||
|             # Initial flow request, check if we have an upstream query string passed in | ||||
|             request.session[SESSION_KEY_GET] = get_params | ||||
|             # Don't check session again as we've either already loaded the plan or we need to plan | ||||
|             if not self.plan: | ||||
|                 request.session[SESSION_KEY_HISTORY] = [] | ||||
| @ -192,8 +195,6 @@ class FlowExecutorView(APIView): | ||||
|                     # To match behaviour with loading an empty flow plan from cache, | ||||
|                     # we don't show an error message here, but rather call _flow_done() | ||||
|                     return self._flow_done() | ||||
|             # Initial flow request, check if we have an upstream query string passed in | ||||
|             request.session[SESSION_KEY_GET] = get_params | ||||
|             # We don't save the Plan after getting the next stage | ||||
|             # as it hasn't been successfully passed yet | ||||
|             try: | ||||
| @ -392,7 +393,11 @@ class FlowExecutorView(APIView): | ||||
|             NEXT_ARG_NAME, "authentik_core:root-redirect" | ||||
|         ) | ||||
|         self.cancel() | ||||
|         if next_param and not is_url_absolute(next_param): | ||||
|             return to_stage_response(self.request, redirect_with_qs(next_param)) | ||||
|         return to_stage_response( | ||||
|             self.request, self.stage_invalid(error_message=_("Invalid next URL")) | ||||
|         ) | ||||
|  | ||||
|     def stage_ok(self) -> HttpResponse: | ||||
|         """Callback called by stages upon successful completion. | ||||
|  | ||||
| @ -30,10 +30,6 @@ class AuthentikOutpostConfig(ManagedAppConfig): | ||||
|     verbose_name = "authentik Outpost" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_outposts_signals(self): | ||||
|         """Load outposts signals""" | ||||
|         self.import_module("authentik.outposts.signals") | ||||
|  | ||||
|     def reconcile_tenant_embedded_outpost(self): | ||||
|         """Ensure embedded outpost""" | ||||
|         from authentik.outposts.models import ( | ||||
|  | ||||
| @ -13,7 +13,6 @@ from rest_framework.viewsets import GenericViewSet | ||||
| from structlog.stdlib import get_logger | ||||
| from structlog.testing import capture_logs | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.applications import user_app_cache_key | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import CacheSerializer, MetaNameSerializer, TypeCreateSerializer | ||||
| @ -23,6 +22,7 @@ from authentik.policies.api.exec import PolicyTestResultSerializer, PolicyTestSe | ||||
| from authentik.policies.models import Policy, PolicyBinding | ||||
| from authentik.policies.process import PolicyProcess | ||||
| from authentik.policies.types import CACHE_PREFIX, PolicyRequest | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @ -35,7 +35,3 @@ class AuthentikPoliciesConfig(ManagedAppConfig): | ||||
|     label = "authentik_policies" | ||||
|     verbose_name = "authentik Policies" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_policies_signals(self): | ||||
|         """Load policies signals""" | ||||
|         self.import_module("authentik.policies.signals") | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from multiprocessing import Pipe, current_process | ||||
| from multiprocessing.connection import Connection | ||||
| from timeit import default_timer | ||||
| from time import perf_counter | ||||
| from typing import Iterator, Optional | ||||
|  | ||||
| from django.core.cache import cache | ||||
| @ -84,10 +84,10 @@ class PolicyEngine: | ||||
|     def _check_cache(self, binding: PolicyBinding): | ||||
|         if not self.use_cache: | ||||
|             return False | ||||
|         before = default_timer() | ||||
|         before = perf_counter() | ||||
|         key = cache_key(binding, self.request) | ||||
|         cached_policy = cache.get(key, None) | ||||
|         duration = max(default_timer() - before, 0) | ||||
|         duration = max(perf_counter() - before, 0) | ||||
|         if not cached_policy: | ||||
|             return False | ||||
|         self.logger.debug( | ||||
|  | ||||
| @ -2,6 +2,8 @@ | ||||
|  | ||||
| from authentik.blueprints.apps import ManagedAppConfig | ||||
|  | ||||
| CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" | ||||
|  | ||||
|  | ||||
| class AuthentikPolicyReputationConfig(ManagedAppConfig): | ||||
|     """Authentik reputation app config""" | ||||
| @ -10,11 +12,3 @@ class AuthentikPolicyReputationConfig(ManagedAppConfig): | ||||
|     label = "authentik_policies_reputation" | ||||
|     verbose_name = "authentik Policies.Reputation" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_policies_reputation_signals(self): | ||||
|         """Load policies.reputation signals""" | ||||
|         self.import_module("authentik.policies.reputation.signals") | ||||
|  | ||||
|     def reconcile_global_load_policies_reputation_tasks(self): | ||||
|         """Load policies.reputation tasks""" | ||||
|         self.import_module("authentik.policies.reputation.tasks") | ||||
|  | ||||
| @ -19,7 +19,6 @@ from authentik.policies.types import PolicyRequest, PolicyResult | ||||
| from authentik.root.middleware import ClientIPMiddleware | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| CACHE_KEY_PREFIX = "goauthentik.io/policies/reputation/scores/" | ||||
|  | ||||
|  | ||||
| def reputation_expiry(): | ||||
|  | ||||
| @ -8,7 +8,7 @@ from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.signals import login_failed | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.policies.reputation.models import CACHE_KEY_PREFIX | ||||
| from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||
| from authentik.policies.reputation.tasks import save_reputation | ||||
| from authentik.root.middleware import ClientIPMiddleware | ||||
| from authentik.stages.identification.signals import identification_failed | ||||
|  | ||||
| @ -7,8 +7,8 @@ from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR | ||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR | ||||
| from authentik.events.models import TaskStatus | ||||
| from authentik.events.system_tasks import SystemTask, prefill_task | ||||
| from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||
| from authentik.policies.reputation.models import Reputation | ||||
| from authentik.policies.reputation.signals import CACHE_KEY_PREFIX | ||||
| from authentik.root.celery import CELERY_APP | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -6,7 +6,8 @@ from django.test import RequestFactory, TestCase | ||||
| from authentik.core.models import User | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.reputation.api import ReputationPolicySerializer | ||||
| from authentik.policies.reputation.models import CACHE_KEY_PREFIX, Reputation, ReputationPolicy | ||||
| from authentik.policies.reputation.apps import CACHE_KEY_PREFIX | ||||
| from authentik.policies.reputation.models import Reputation, ReputationPolicy | ||||
| from authentik.policies.reputation.tasks import save_reputation | ||||
| from authentik.policies.types import PolicyRequest | ||||
| from authentik.stages.password import BACKEND_INBUILT | ||||
|  | ||||
| @ -15,13 +15,13 @@ from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.providers import ProviderSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||
| from authentik.core.models import Provider | ||||
| from authentik.providers.oauth2.id_token import IDToken | ||||
| from authentik.providers.oauth2.models import AccessToken, OAuth2Provider, ScopeMapping | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class OAuth2ProviderSerializer(ProviderSerializer): | ||||
|  | ||||
| @ -36,8 +36,21 @@ class TestAuthorize(OAuthTestCase): | ||||
|  | ||||
|     def test_invalid_grant_type(self): | ||||
|         """Test with invalid grant type""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris="http://local.invalid/Foo", | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|             request = self.factory.get("/", data={"response_type": "invalid"}) | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
|                     "response_type": "invalid", | ||||
|                     "client_id": "test", | ||||
|                     "redirect_uri": "http://local.invalid/Foo", | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|  | ||||
|     def test_invalid_client_id(self): | ||||
| @ -344,7 +357,12 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 ] | ||||
|             ) | ||||
|         ) | ||||
|         Application.objects.create(name="app", slug="app", provider=provider) | ||||
|         provider.property_mappings.add( | ||||
|             ScopeMapping.objects.create( | ||||
|                 name=generate_id(), scope_name="test", expression="""return {"sub": "foo"}""" | ||||
|             ) | ||||
|         ) | ||||
|         Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||
|         state = generate_id() | ||||
|         user = create_test_admin_user() | ||||
|         self.client.force_login(user) | ||||
| @ -365,7 +383,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                     "response_type": "id_token", | ||||
|                     "client_id": "test", | ||||
|                     "state": state, | ||||
|                     "scope": "openid", | ||||
|                     "scope": "openid test", | ||||
|                     "redirect_uri": "http://localhost", | ||||
|                     "nonce": generate_id(), | ||||
|                 }, | ||||
| @ -390,6 +408,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             ) | ||||
|             jwt = self.validate_jwt(token, provider) | ||||
|             self.assertEqual(jwt["amr"], ["pwd"]) | ||||
|             self.assertEqual(jwt["sub"], "foo") | ||||
|             self.assertAlmostEqual( | ||||
|                 jwt["exp"] - now().timestamp(), | ||||
|                 expires, | ||||
|  | ||||
| @ -4,9 +4,10 @@ from urllib.parse import urlencode | ||||
|  | ||||
| from django.urls import reverse | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.models import Application, Group | ||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.models import PolicyBinding | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||
| @ -77,3 +78,23 @@ class TesOAuth2DeviceInit(OAuthTestCase): | ||||
|             + "?" | ||||
|             + urlencode({QS_KEY_CODE: token.user_code}), | ||||
|         ) | ||||
|  | ||||
|     def test_device_init_denied(self): | ||||
|         """Test device init""" | ||||
|         group = Group.objects.create(name="foo") | ||||
|         PolicyBinding.objects.create( | ||||
|             group=group, | ||||
|             target=self.application, | ||||
|             order=0, | ||||
|         ) | ||||
|         token = DeviceToken.objects.create( | ||||
|             user_code="foo", | ||||
|             provider=self.provider, | ||||
|         ) | ||||
|         res = self.client.get( | ||||
|             reverse("authentik_providers_oauth2_root:device-login") | ||||
|             + "?" | ||||
|             + urlencode({QS_KEY_CODE: token.user_code}) | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         self.assertIn(b"Permission denied", res.content) | ||||
|  | ||||
| @ -121,44 +121,18 @@ class OAuthAuthorizationParams: | ||||
|         redirect_uri = query_dict.get("redirect_uri", "") | ||||
|  | ||||
|         response_type = query_dict.get("response_type", "") | ||||
|         grant_type = None | ||||
|         # Determine which flow to use. | ||||
|         if response_type in [ResponseTypes.CODE]: | ||||
|             grant_type = GrantTypes.AUTHORIZATION_CODE | ||||
|         elif response_type in [ | ||||
|             ResponseTypes.ID_TOKEN, | ||||
|             ResponseTypes.ID_TOKEN_TOKEN, | ||||
|         ]: | ||||
|             grant_type = GrantTypes.IMPLICIT | ||||
|         elif response_type in [ | ||||
|             ResponseTypes.CODE_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||
|         ]: | ||||
|             grant_type = GrantTypes.HYBRID | ||||
|  | ||||
|         # Grant type validation. | ||||
|         if not grant_type: | ||||
|             LOGGER.warning("Invalid response type", type=response_type) | ||||
|             raise AuthorizeError(redirect_uri, "unsupported_response_type", "", state) | ||||
|  | ||||
|         # Validate and check the response_mode against the predefined dict | ||||
|         # Set to Query or Fragment if not defined in request | ||||
|         response_mode = query_dict.get("response_mode", False) | ||||
|  | ||||
|         if response_mode not in ResponseMode.values: | ||||
|             response_mode = ResponseMode.QUERY | ||||
|  | ||||
|             if grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]: | ||||
|                 response_mode = ResponseMode.FRAGMENT | ||||
|  | ||||
|         max_age = query_dict.get("max_age") | ||||
|         return OAuthAuthorizationParams( | ||||
|             client_id=query_dict.get("client_id", ""), | ||||
|             redirect_uri=redirect_uri, | ||||
|             response_type=response_type, | ||||
|             response_mode=response_mode, | ||||
|             grant_type=grant_type, | ||||
|             grant_type="", | ||||
|             scope=set(query_dict.get("scope", "").split()), | ||||
|             state=state, | ||||
|             nonce=query_dict.get("nonce"), | ||||
| @ -178,6 +152,7 @@ class OAuthAuthorizationParams: | ||||
|             LOGGER.warning("Invalid client identifier", client_id=self.client_id) | ||||
|             raise ClientIdError(client_id=self.client_id) | ||||
|         self.check_redirect_uri() | ||||
|         self.check_grant() | ||||
|         self.check_scope(github_compat) | ||||
|         self.check_nonce() | ||||
|         self.check_code_challenge() | ||||
| @ -186,6 +161,34 @@ class OAuthAuthorizationParams: | ||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||
|             ) | ||||
|  | ||||
|     def check_grant(self): | ||||
|         """Check grant""" | ||||
|         # Determine which flow to use. | ||||
|         if self.response_type in [ResponseTypes.CODE]: | ||||
|             self.grant_type = GrantTypes.AUTHORIZATION_CODE | ||||
|         elif self.response_type in [ | ||||
|             ResponseTypes.ID_TOKEN, | ||||
|             ResponseTypes.ID_TOKEN_TOKEN, | ||||
|         ]: | ||||
|             self.grant_type = GrantTypes.IMPLICIT | ||||
|         elif self.response_type in [ | ||||
|             ResponseTypes.CODE_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||
|         ]: | ||||
|             self.grant_type = GrantTypes.HYBRID | ||||
|  | ||||
|         # Grant type validation. | ||||
|         if not self.grant_type: | ||||
|             LOGGER.warning("Invalid response type", type=self.response_type) | ||||
|             raise AuthorizeError(self.redirect_uri, "unsupported_response_type", "", self.state) | ||||
|  | ||||
|         if self.response_mode not in ResponseMode.values: | ||||
|             self.response_mode = ResponseMode.QUERY | ||||
|  | ||||
|             if self.grant_type in [GrantTypes.IMPLICIT, GrantTypes.HYBRID]: | ||||
|                 self.response_mode = ResponseMode.FRAGMENT | ||||
|  | ||||
|     def check_redirect_uri(self): | ||||
|         """Redirect URI validation.""" | ||||
|         allowed_redirect_urls = self.provider.redirect_uris.split() | ||||
| @ -257,9 +260,9 @@ class OAuthAuthorizationParams: | ||||
|         if SCOPE_OFFLINE_ACCESS in self.scope: | ||||
|             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess | ||||
|             if PROMPT_CONSENT not in self.prompt: | ||||
|                 raise AuthorizeError( | ||||
|                     self.redirect_uri, "consent_required", self.grant_type, self.state | ||||
|                 ) | ||||
|                 # Instead of ignoring the `offline_access` scope when `prompt` | ||||
|                 # isn't set to `consent`, we set override it ourselves | ||||
|                 self.prompt.add(PROMPT_CONSENT) | ||||
|             if self.response_type not in [ | ||||
|                 ResponseTypes.CODE, | ||||
|                 ResponseTypes.CODE_TOKEN, | ||||
|  | ||||
| @ -12,10 +12,11 @@ from django.views.decorators.csrf import csrf_exempt | ||||
| from rest_framework.throttling import AnonRateThrottle | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application | ||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -38,7 +39,9 @@ class DeviceView(View): | ||||
|         ).first() | ||||
|         if not provider: | ||||
|             return HttpResponseBadRequest() | ||||
|         if not get_application(provider): | ||||
|         try: | ||||
|             _ = provider.application | ||||
|         except Application.DoesNotExist: | ||||
|             return HttpResponseBadRequest() | ||||
|         self.provider = provider | ||||
|         self.client_id = client_id | ||||
|  | ||||
| @ -1,11 +1,10 @@ | ||||
| """Device flow views""" | ||||
|  | ||||
| from typing import Optional | ||||
| from typing import Any, Optional | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.utils.translation import gettext as _ | ||||
| from django.views import View | ||||
| from rest_framework.exceptions import ErrorDetail | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.fields import CharField, IntegerField | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| @ -18,6 +17,7 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, | ||||
| from authentik.flows.stage import ChallengeStageView | ||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
| from authentik.lib.utils.urls import redirect_with_qs | ||||
| from authentik.policies.views import PolicyAccessView | ||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||
| from authentik.providers.oauth2.views.device_finish import ( | ||||
|     PLAN_CONTEXT_DEVICE, | ||||
| @ -44,32 +44,36 @@ def get_application(provider: OAuth2Provider) -> Optional[Application]: | ||||
|         return None | ||||
|  | ||||
|  | ||||
| def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]: | ||||
|     """Validate user token""" | ||||
|     token = DeviceToken.objects.filter( | ||||
|         user_code=code, | ||||
|     ).first() | ||||
|     if not token: | ||||
|         return None | ||||
| class CodeValidatorView(PolicyAccessView): | ||||
|     """Helper to validate frontside token""" | ||||
|  | ||||
|     app = get_application(token.provider) | ||||
|     if not app: | ||||
|         return None | ||||
|     def __init__(self, code: str, **kwargs: Any) -> None: | ||||
|         super().__init__(**kwargs) | ||||
|         self.code = code | ||||
|  | ||||
|     scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider) | ||||
|     planner = FlowPlanner(token.provider.authorization_flow) | ||||
|     def resolve_provider_application(self): | ||||
|         self.token = DeviceToken.objects.filter(user_code=self.code).first() | ||||
|         if not self.token: | ||||
|             raise Application.DoesNotExist | ||||
|         self.provider = self.token.provider | ||||
|         self.application = self.token.provider.application | ||||
|  | ||||
|     def get(self, request: HttpRequest, *args, **kwargs): | ||||
|         scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider) | ||||
|         planner = FlowPlanner(self.provider.authorization_flow) | ||||
|         planner.allow_empty_flows = True | ||||
|         planner.use_cache = False | ||||
|         try: | ||||
|             plan = planner.plan( | ||||
|                 request, | ||||
|                 { | ||||
|                     PLAN_CONTEXT_SSO: True, | ||||
|                 PLAN_CONTEXT_APPLICATION: app, | ||||
|                     PLAN_CONTEXT_APPLICATION: self.application, | ||||
|                     # OAuth2 related params | ||||
|                 PLAN_CONTEXT_DEVICE: token, | ||||
|                     PLAN_CONTEXT_DEVICE: self.token, | ||||
|                     # Consent related params | ||||
|                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") | ||||
|                 % {"application": app.name}, | ||||
|                     % {"application": self.application.name}, | ||||
|                     PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, | ||||
|                 }, | ||||
|             ) | ||||
| @ -81,11 +85,11 @@ def validate_code(code: int, request: HttpRequest) -> Optional[HttpResponse]: | ||||
|         return redirect_with_qs( | ||||
|             "authentik_core:if-flow", | ||||
|             request.GET, | ||||
|         flow_slug=token.provider.authorization_flow.slug, | ||||
|             flow_slug=self.token.provider.authorization_flow.slug, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class DeviceEntryView(View): | ||||
| class DeviceEntryView(PolicyAccessView): | ||||
|     """View used to initiate the device-code flow, url entered by endusers""" | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest) -> HttpResponse: | ||||
| @ -95,7 +99,9 @@ class DeviceEntryView(View): | ||||
|             LOGGER.info("Brand has no device code flow configured", brand=brand) | ||||
|             return HttpResponse(status=404) | ||||
|         if QS_KEY_CODE in request.GET: | ||||
|             validation = validate_code(request.GET[QS_KEY_CODE], request) | ||||
|             validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch( | ||||
|                 request | ||||
|             ) | ||||
|             if validation: | ||||
|                 return validation | ||||
|             LOGGER.info("Got code from query parameter but no matching token found") | ||||
| @ -130,6 +136,13 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse): | ||||
|     code = IntegerField() | ||||
|     component = CharField(default="ak-provider-oauth2-device-code") | ||||
|  | ||||
|     def validate_code(self, code: int) -> HttpResponse | None: | ||||
|         """Validate code and save the returned http response""" | ||||
|         response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request) | ||||
|         if not response: | ||||
|             raise ValidationError(_("Invalid code"), "invalid") | ||||
|         return response | ||||
|  | ||||
|  | ||||
| class OAuthDeviceCodeStage(ChallengeStageView): | ||||
|     """Flow challenge for users to enter device codes""" | ||||
| @ -145,12 +158,4 @@ class OAuthDeviceCodeStage(ChallengeStageView): | ||||
|         ) | ||||
|  | ||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||
|         code = response.validated_data["code"] | ||||
|         validation = validate_code(code, self.request) | ||||
|         if not validation: | ||||
|             response._errors.setdefault("code", []) | ||||
|             response._errors["code"].append(ErrorDetail(_("Invalid code"), "invalid")) | ||||
|             return self.challenge_invalid(response) | ||||
|         # Run cancel to cleanup the current flow | ||||
|         self.executor.cancel() | ||||
|         return validation | ||||
|         return response.validated_data["code"] | ||||
|  | ||||
| @ -101,8 +101,8 @@ class UserInfoView(View): | ||||
|                     value=value, | ||||
|                 ) | ||||
|                 continue | ||||
|             LOGGER.debug("updated scope", scope=scope) | ||||
|             always_merger.merge(final_claims, value) | ||||
|             LOGGER.debug("updated scope", scope=scope) | ||||
|         return final_claims | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
| @ -121,8 +121,9 @@ class UserInfoView(View): | ||||
|         """Handle GET Requests for UserInfo""" | ||||
|         if not self.token: | ||||
|             return HttpResponseBadRequest() | ||||
|         claims = self.get_claims(self.token.provider, self.token) | ||||
|         claims["sub"] = self.token.id_token.sub | ||||
|         claims = {} | ||||
|         claims.setdefault("sub", self.token.id_token.sub) | ||||
|         claims.update(self.get_claims(self.token.provider, self.token)) | ||||
|         if self.token.id_token.nonce: | ||||
|             claims["nonce"] = self.token.id_token.nonce | ||||
|         response = TokenResponse(claims) | ||||
|  | ||||
| @ -10,7 +10,3 @@ class AuthentikProviderProxyConfig(ManagedAppConfig): | ||||
|     label = "authentik_providers_proxy" | ||||
|     verbose_name = "authentik Providers.Proxy" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_providers_proxy_signals(self): | ||||
|         """Load proxy signals""" | ||||
|         self.import_module("authentik.providers.proxy.signals") | ||||
|  | ||||
| @ -22,7 +22,6 @@ from rest_framework.serializers import PrimaryKeyRelatedField, ValidationError | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.providers import ProviderSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import PassiveSerializer, PropertyMappingPreviewSerializer | ||||
| @ -33,6 +32,7 @@ from authentik.providers.saml.processors.assertion import AssertionProcessor | ||||
| from authentik.providers.saml.processors.authn_request_parser import AuthNRequest | ||||
| from authentik.providers.saml.processors.metadata import MetadataProcessor | ||||
| from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.sources.saml.processors.constants import SAML_BINDING_POST, SAML_BINDING_REDIRECT | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
| @ -10,7 +10,3 @@ class AuthentikProviderSCIMConfig(ManagedAppConfig): | ||||
|     label = "authentik_providers_scim" | ||||
|     verbose_name = "authentik Providers.SCIM" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_signals(self): | ||||
|         """Load signals""" | ||||
|         self.import_module("authentik.providers.scim.signals") | ||||
|  | ||||
| @ -15,10 +15,10 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.policies.event_matcher.models import model_choices | ||||
| from authentik.rbac.api.rbac import PermissionAssignSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
| from authentik.rbac.models import Role | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -16,11 +16,11 @@ from rest_framework.response import Response | ||||
| from rest_framework.serializers import ModelSerializer | ||||
| from rest_framework.viewsets import GenericViewSet | ||||
|  | ||||
| from authentik.api.decorators import permission_required | ||||
| from authentik.core.api.groups import GroupMemberSerializer | ||||
| from authentik.core.models import User, UserTypes | ||||
| from authentik.policies.event_matcher.models import model_choices | ||||
| from authentik.rbac.api.rbac import PermissionAssignSerializer | ||||
| from authentik.rbac.decorators import permission_required | ||||
|  | ||||
|  | ||||
| class UserObjectPermissionSerializer(ModelSerializer): | ||||
|  | ||||
| @ -10,7 +10,3 @@ class AuthentikRBACConfig(ManagedAppConfig): | ||||
|     label = "authentik_rbac" | ||||
|     verbose_name = "authentik RBAC" | ||||
|     default = True | ||||
|  | ||||
|     def reconcile_global_load_rbac_signals(self): | ||||
|         """Load rbac signals""" | ||||
|         self.import_module("authentik.rbac.signals") | ||||
|  | ||||
| @ -14,18 +14,23 @@ LOGGER = get_logger() | ||||
| def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[list[str]] = None): | ||||
|     """Check permissions for a single custom action""" | ||||
| 
 | ||||
|     def wrapper_outter(func: Callable): | ||||
|     def _check_obj_perm(self: ModelViewSet, request: Request): | ||||
|         # Check obj_perm both globally and on the specific object | ||||
|         # Having the global permission has higher priority | ||||
|         if request.user.has_perm(obj_perm): | ||||
|             return | ||||
|         obj = self.get_object() | ||||
|         if not request.user.has_perm(obj_perm, obj): | ||||
|             LOGGER.debug("denying access for object", user=request.user, perm=obj_perm, obj=obj) | ||||
|             self.permission_denied(request) | ||||
| 
 | ||||
|     def wrapper_outer(func: Callable): | ||||
|         """Check permissions for a single custom action""" | ||||
| 
 | ||||
|         @wraps(func) | ||||
|         def wrapper(self: ModelViewSet, request: Request, *args, **kwargs) -> Response: | ||||
|             if obj_perm: | ||||
|                 obj = self.get_object() | ||||
|                 if not request.user.has_perm(obj_perm, obj): | ||||
|                     LOGGER.debug( | ||||
|                         "denying access for object", user=request.user, perm=obj_perm, obj=obj | ||||
|                     ) | ||||
|                     return self.permission_denied(request) | ||||
|                 _check_obj_perm(self, request) | ||||
|             if global_perms: | ||||
|                 for other_perm in global_perms: | ||||
|                     if not request.user.has_perm(other_perm): | ||||
| @ -35,4 +40,4 @@ def permission_required(obj_perm: Optional[str] = None, global_perms: Optional[l | ||||
| 
 | ||||
|         return wrapper | ||||
| 
 | ||||
|     return wrapper_outter | ||||
|     return wrapper_outer | ||||
							
								
								
									
										58
									
								
								authentik/rbac/tests/test_decorators.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								authentik/rbac/tests/test_decorators.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| """test decorators api""" | ||||
|  | ||||
| from django.urls import reverse | ||||
| from guardian.shortcuts import assign_perm | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Application | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.lib.generators import generate_id | ||||
|  | ||||
|  | ||||
| class TestAPIDecorators(APITestCase): | ||||
|     """test decorators api""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         super().setUp() | ||||
|         self.user = create_test_user() | ||||
|  | ||||
|     def test_obj_perm_denied(self): | ||||
|         """Test object perm denied""" | ||||
|         self.client.force_login(self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 403) | ||||
|  | ||||
|     def test_obj_perm_global(self): | ||||
|         """Test object perm successful (global)""" | ||||
|         assign_perm("authentik_core.view_application", self.user) | ||||
|         assign_perm("authentik_events.view_event", self.user) | ||||
|         self.client.force_login(self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_obj_perm_scoped(self): | ||||
|         """Test object perm successful (scoped)""" | ||||
|         assign_perm("authentik_events.view_event", self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         assign_perm("authentik_core.view_application", self.user, app) | ||||
|         self.client.force_login(self.user) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_other_perm_denied(self): | ||||
|         """Test other perm denied""" | ||||
|         self.client.force_login(self.user) | ||||
|         app = Application.objects.create(name=generate_id(), slug=generate_id()) | ||||
|         assign_perm("authentik_core.view_application", self.user, app) | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:application-metrics", kwargs={"slug": app.slug}) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 403) | ||||
| @ -91,13 +91,10 @@ def _get_startup_tasks_default_tenant() -> list[Callable]: | ||||
| def _get_startup_tasks_all_tenants() -> list[Callable]: | ||||
|     """Get all tasks to be run on startup for all tenants""" | ||||
|     from authentik.admin.tasks import clear_update_notifications | ||||
|     from authentik.outposts.tasks import outpost_connection_discovery, outpost_controller_all | ||||
|     from authentik.providers.proxy.tasks import proxy_set_defaults | ||||
|  | ||||
|     return [ | ||||
|         clear_update_notifications, | ||||
|         outpost_connection_discovery, | ||||
|         outpost_controller_all, | ||||
|         proxy_set_defaults, | ||||
|     ] | ||||
|  | ||||
|  | ||||
| @ -7,6 +7,8 @@ from psycopg import connect | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
| QUERY = """SELECT id FROM public.authentik_install_id ORDER BY id LIMIT 1;""" | ||||
|  | ||||
|  | ||||
| @lru_cache | ||||
| def get_install_id() -> str: | ||||
| @ -18,7 +20,7 @@ def get_install_id() -> str: | ||||
|     if settings.TEST: | ||||
|         return str(uuid4()) | ||||
|     with connection.cursor() as cursor: | ||||
|         cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;") | ||||
|         cursor.execute(QUERY) | ||||
|         return cursor.fetchone()[0] | ||||
|  | ||||
|  | ||||
| @ -38,5 +40,5 @@ def get_install_id_raw(): | ||||
|         sslkey=CONFIG.get("postgresql.sslkey"), | ||||
|     ) | ||||
|     cursor = conn.cursor() | ||||
|     cursor.execute("SELECT id FROM public.authentik_install_id LIMIT 1;") | ||||
|     cursor.execute(QUERY) | ||||
|     return cursor.fetchone()[0] | ||||
|  | ||||
| @ -1,8 +1,7 @@ | ||||
| """Dynamically set SameSite depending if the upstream connection is TLS or not""" | ||||
|  | ||||
| from hashlib import sha512 | ||||
| from time import time | ||||
| from timeit import default_timer | ||||
| from time import perf_counter, time | ||||
| from typing import Any, Callable, Optional | ||||
|  | ||||
| from django.conf import settings | ||||
| @ -294,14 +293,14 @@ class LoggingMiddleware: | ||||
|         self.get_response = get_response | ||||
|  | ||||
|     def __call__(self, request: HttpRequest) -> HttpResponse: | ||||
|         start = default_timer() | ||||
|         start = perf_counter() | ||||
|         response = self.get_response(request) | ||||
|         status_code = response.status_code | ||||
|         kwargs = { | ||||
|             "request_id": getattr(request, "request_id", None), | ||||
|         } | ||||
|         kwargs.update(getattr(response, "ak_context", {})) | ||||
|         self.log(request, status_code, int((default_timer() - start) * 1000), **kwargs) | ||||
|         self.log(request, status_code, int((perf_counter() - start) * 1000), **kwargs) | ||||
|         return response | ||||
|  | ||||
|     def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs): | ||||
|  | ||||
| @ -69,7 +69,6 @@ TENANT_APPS = [ | ||||
|     "authentik.admin", | ||||
|     "authentik.api", | ||||
|     "authentik.crypto", | ||||
|     "authentik.events", | ||||
|     "authentik.flows", | ||||
|     "authentik.outposts", | ||||
|     "authentik.policies.dummy", | ||||
| @ -482,13 +481,6 @@ def _update_settings(app_path: str): | ||||
|         pass | ||||
|  | ||||
|  | ||||
| # Load subapps's settings | ||||
| for _app in set(SHARED_APPS + TENANT_APPS): | ||||
|     if not _app.startswith("authentik"): | ||||
|         continue | ||||
|     _update_settings(f"{_app}.settings") | ||||
| _update_settings("data.user_settings") | ||||
|  | ||||
| if DEBUG: | ||||
|     CELERY["task_always_eager"] = True | ||||
|     os.environ[ENV_GIT_HASH_KEY] = "dev" | ||||
| @ -509,5 +501,17 @@ try: | ||||
| except ImportError: | ||||
|     pass | ||||
|  | ||||
| # Import events after other apps since it relies on tasks and other things from all apps | ||||
| # being imported for @prefill_task | ||||
| TENANT_APPS.append("authentik.events") | ||||
|  | ||||
|  | ||||
| # Load subapps's settings | ||||
| for _app in set(SHARED_APPS + TENANT_APPS): | ||||
|     if not _app.startswith("authentik"): | ||||
|         continue | ||||
|     _update_settings(f"{_app}.settings") | ||||
| _update_settings("data.user_settings") | ||||
|  | ||||
| SHARED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | ||||
| INSTALLED_APPS = list(OrderedDict.fromkeys(SHARED_APPS + TENANT_APPS)) | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	