Compare commits

2 Commits: enterprise ... providers/

| Author | SHA1 | Date |
|---|---|---|
|  | 7549a6b83d |  |
|  | bb45b714e2 |  |
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2025.6.3 | current_version = 2025.6.1 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
| @ -21,8 +21,6 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:package.json] | [bumpversion:file:package.json] | ||||||
|  |  | ||||||
| [bumpversion:file:package-lock.json] |  | ||||||
|  |  | ||||||
| [bumpversion:file:docker-compose.yml] | [bumpversion:file:docker-compose.yml] | ||||||
|  |  | ||||||
| [bumpversion:file:schema.yml] | [bumpversion:file:schema.yml] | ||||||
| @ -33,4 +31,6 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:internal/constants/constants.go] | [bumpversion:file:internal/constants/constants.go] | ||||||
|  |  | ||||||
|  | [bumpversion:file:web/src/common/constants.ts] | ||||||
|  |  | ||||||
| [bumpversion:file:lifecycle/aws/template.yaml] | [bumpversion:file:lifecycle/aws/template.yaml] | ||||||
|  | |||||||
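
The `[bumpversion]` hunk above only differs in `current_version` (2025.6.3 versus 2025.6.1) and in which `[bumpversion:file:...]` sections are listed; the `parse` pattern is identical on both sides. As a standalone illustration (not part of the diff), the sketch below runs that named-group pattern, written with single backslashes rather than the escaped `\\d` form shown in the config, against a plain release and a release-candidate string.

```python
# Standalone sketch: the named-group pattern from the [bumpversion] "parse" line,
# written as a normal Python regex (single backslashes instead of the escaped
# form shown in the config file).
import re

PARSE = re.compile(
    r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
    r"(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\d*))?"
)

for version in ("2025.6.3", "2025.6.3-rc2"):
    match = PARSE.fullmatch(version)
    # groupdict() exposes major/minor/patch plus the optional rc_t/rc_n parts.
    print(version, match.groupdict() if match else "no match")
```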
| @ -7,9 +7,6 @@ charset = utf-8 | |||||||
| trim_trailing_whitespace = true | trim_trailing_whitespace = true | ||||||
| insert_final_newline = true | insert_final_newline = true | ||||||
|  |  | ||||||
| [*.toml] |  | ||||||
| indent_size = 2 |  | ||||||
|  |  | ||||||
| [*.html] | [*.html] | ||||||
| indent_size = 2 | indent_size = 2 | ||||||
|  |  | ||||||
|  | |||||||
| @ -38,8 +38,6 @@ jobs: | |||||||
|       # Needed for attestation |       # Needed for attestation | ||||||
|       id-token: write |       id-token: write | ||||||
|       attestations: write |       attestations: write | ||||||
|       # Needed for checkout |  | ||||||
|       contents: read |  | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v4 |       - uses: actions/checkout@v4 | ||||||
|       - uses: docker/setup-qemu-action@v3.6.0 |       - uses: docker/setup-qemu-action@v3.6.0 | ||||||
|  | |||||||

.github/workflows/ci-main-daily.yml (3 changes, vendored)

| @ -9,15 +9,14 @@ on: | |||||||
|  |  | ||||||
| jobs: | jobs: | ||||||
|   test-container: |   test-container: | ||||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} |  | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     strategy: |     strategy: | ||||||
|       fail-fast: false |       fail-fast: false | ||||||
|       matrix: |       matrix: | ||||||
|         version: |         version: | ||||||
|           - docs |           - docs | ||||||
|           - version-2025-4 |  | ||||||
|           - version-2025-2 |           - version-2025-2 | ||||||
|  |           - version-2024-12 | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v4 |       - uses: actions/checkout@v4 | ||||||
|       - run: | |       - run: | | ||||||
|  | |||||||

.github/workflows/ci-main.yml (6 changes, vendored)

| @ -202,7 +202,7 @@ jobs: | |||||||
|         uses: actions/cache@v4 |         uses: actions/cache@v4 | ||||||
|         with: |         with: | ||||||
|           path: web/dist |           path: web/dist | ||||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b |           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||||
|       - name: prepare web ui |       - name: prepare web ui | ||||||
|         if: steps.cache-web.outputs.cache-hit != 'true' |         if: steps.cache-web.outputs.cache-hit != 'true' | ||||||
|         working-directory: web |         working-directory: web | ||||||
| @ -247,13 +247,11 @@ jobs: | |||||||
|       # Needed for attestation |       # Needed for attestation | ||||||
|       id-token: write |       id-token: write | ||||||
|       attestations: write |       attestations: write | ||||||
|       # Needed for checkout |  | ||||||
|       contents: read |  | ||||||
|     needs: ci-core-mark |     needs: ci-core-mark | ||||||
|     uses: ./.github/workflows/_reusable-docker-build.yaml |     uses: ./.github/workflows/_reusable-docker-build.yaml | ||||||
|     secrets: inherit |     secrets: inherit | ||||||
|     with: |     with: | ||||||
|       image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }} |       image_name: ghcr.io/goauthentik/dev-server | ||||||
|       release: false |       release: false | ||||||
|   pr-comment: |   pr-comment: | ||||||
|     needs: |     needs: | ||||||
|  | |||||||

.github/workflows/ci-outpost.yml (1 change, vendored)

| @ -59,7 +59,6 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           jobs: ${{ toJSON(needs) }} |           jobs: ${{ toJSON(needs) }} | ||||||
|   build-container: |   build-container: | ||||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} |  | ||||||
|     timeout-minutes: 120 |     timeout-minutes: 120 | ||||||
|     needs: |     needs: | ||||||
|       - ci-outpost-mark |       - ci-outpost-mark | ||||||
|  | |||||||

.github/workflows/ci-website.yml (24 changes, vendored)

| @ -41,29 +41,7 @@ jobs: | |||||||
|       - name: test |       - name: test | ||||||
|         working-directory: website/ |         working-directory: website/ | ||||||
|         run: npm test |         run: npm test | ||||||
|   build: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     name: ${{ matrix.job }} |  | ||||||
|     strategy: |  | ||||||
|       fail-fast: false |  | ||||||
|       matrix: |  | ||||||
|         job: |  | ||||||
|           - build |  | ||||||
|           - build:integrations |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|       - uses: actions/setup-node@v4 |  | ||||||
|         with: |  | ||||||
|           node-version-file: website/package.json |  | ||||||
|           cache: "npm" |  | ||||||
|           cache-dependency-path: website/package-lock.json |  | ||||||
|       - working-directory: website/ |  | ||||||
|         run: npm ci |  | ||||||
|       - name: build |  | ||||||
|         working-directory: website/ |  | ||||||
|         run: npm run ${{ matrix.job }} |  | ||||||
|   build-container: |   build-container: | ||||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} |  | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     permissions: |     permissions: | ||||||
|       # Needed to upload container images to ghcr.io |       # Needed to upload container images to ghcr.io | ||||||
| @ -116,11 +94,9 @@ jobs: | |||||||
|     needs: |     needs: | ||||||
|       - lint |       - lint | ||||||
|       - test |       - test | ||||||
|       - build |  | ||||||
|       - build-container |       - build-container | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - uses: re-actors/alls-green@release/v1 |       - uses: re-actors/alls-green@release/v1 | ||||||
|         with: |         with: | ||||||
|           jobs: ${{ toJSON(needs) }} |           jobs: ${{ toJSON(needs) }} | ||||||
|           allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }} |  | ||||||
|  | |||||||

.github/workflows/codeql-analysis.yml (2 changes, vendored)

| @ -2,7 +2,7 @@ name: "CodeQL" | |||||||
|  |  | ||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     branches: [main, next, version*] |     branches: [main, "*", next, version*] | ||||||
|   pull_request: |   pull_request: | ||||||
|     branches: [main] |     branches: [main] | ||||||
|   schedule: |   schedule: | ||||||
|  | |||||||

.github/workflows/repo-mirror-cleanup.yml (21 changes, vendored)

| @ -1,21 +0,0 @@ | |||||||
| name: "authentik-repo-mirror-cleanup" |  | ||||||
|  |  | ||||||
| on: |  | ||||||
|   workflow_dispatch: |  | ||||||
|  |  | ||||||
| jobs: |  | ||||||
|   to_internal: |  | ||||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           fetch-depth: 0 |  | ||||||
|       - if: ${{ env.MIRROR_KEY != '' }} |  | ||||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb |  | ||||||
|         with: |  | ||||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git |  | ||||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} |  | ||||||
|           args: --tags --force --prune |  | ||||||
|         env: |  | ||||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} |  | ||||||

.github/workflows/repo-mirror.yml (9 changes, vendored)

| @ -11,10 +11,11 @@ jobs: | |||||||
|         with: |         with: | ||||||
|           fetch-depth: 0 |           fetch-depth: 0 | ||||||
|       - if: ${{ env.MIRROR_KEY != '' }} |       - if: ${{ env.MIRROR_KEY != '' }} | ||||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb |         uses: pixta-dev/repository-mirroring-action@v1 | ||||||
|         with: |         with: | ||||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git |           target_repo_url: | ||||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} |             git@github.com:goauthentik/authentik-internal.git | ||||||
|           args: --tags --force |           ssh_private_key: | ||||||
|  |             ${{ secrets.GH_MIRROR_KEY }} | ||||||
|         env: |         env: | ||||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} |           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} | ||||||
|  | |||||||
| @ -16,7 +16,6 @@ env: | |||||||
|  |  | ||||||
| jobs: | jobs: | ||||||
|   compile: |   compile: | ||||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} |  | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|       - id: generate_token |       - id: generate_token | ||||||
|  | |||||||

.vscode/settings.json (4 changes, vendored)

| @ -6,15 +6,13 @@ | |||||||
|         "!Context scalar", |         "!Context scalar", | ||||||
|         "!Enumerate sequence", |         "!Enumerate sequence", | ||||||
|         "!Env scalar", |         "!Env scalar", | ||||||
|         "!Env sequence", |  | ||||||
|         "!Find sequence", |         "!Find sequence", | ||||||
|         "!Format sequence", |         "!Format sequence", | ||||||
|         "!If sequence", |         "!If sequence", | ||||||
|         "!Index scalar", |         "!Index scalar", | ||||||
|         "!KeyOf scalar", |         "!KeyOf scalar", | ||||||
|         "!Value scalar", |         "!Value scalar", | ||||||
|         "!AtIndex scalar", |         "!AtIndex scalar" | ||||||
|         "!ParseJSON scalar" |  | ||||||
|     ], |     ], | ||||||
|     "typescript.preferences.importModuleSpecifier": "non-relative", |     "typescript.preferences.importModuleSpecifier": "non-relative", | ||||||
|     "typescript.preferences.importModuleSpecifierEnding": "index", |     "typescript.preferences.importModuleSpecifierEnding": "index", | ||||||
|  | |||||||
| @ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | |||||||
|     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" |     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||||
|  |  | ||||||
| # Stage 4: Download uv | # Stage 4: Download uv | ||||||
| FROM ghcr.io/astral-sh/uv:0.7.17 AS uv | FROM ghcr.io/astral-sh/uv:0.7.13 AS uv | ||||||
| # Stage 5: Base python image | # Stage 5: Base python image | ||||||
| FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base | FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base | ||||||
|  |  | ||||||
| ENV VENV_PATH="/ak-root/.venv" \ | ENV VENV_PATH="/ak-root/.venv" \ | ||||||
|     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ |     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ | ||||||
|  | |||||||

Makefile (10 changes)

| @ -86,10 +86,6 @@ dev-create-db: | |||||||
|  |  | ||||||
| dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | ||||||
|  |  | ||||||
| update-test-mmdb:  ## Update test GeoIP and ASN Databases |  | ||||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb |  | ||||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb |  | ||||||
|  |  | ||||||
| ######################### | ######################### | ||||||
| ## API Schema | ## API Schema | ||||||
| ######################### | ######################### | ||||||
| @ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | |||||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||||
| 		--git-repo-id authentik \ | 		--git-repo-id authentik \ | ||||||
| 		--git-user-id goauthentik | 		--git-user-id goauthentik | ||||||
|  | 	mkdir -p web/node_modules/@goauthentik/api | ||||||
| 	cd ${PWD}/${GEN_API_TS} && npm link | 	cd ${PWD}/${GEN_API_TS} && npm i | ||||||
| 	cd ${PWD}/web && npm link @goauthentik/api | 	\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||||
|  |  | ||||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python | gen-client-py: gen-clean-py ## Build and install the authentik API for Python | ||||||
| 	docker run \ | 	docker run \ | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2025.6.3" | __version__ = "2025.6.1" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -37,7 +37,6 @@ entries: | |||||||
|     - attrs: |     - attrs: | ||||||
|           attributes: |           attributes: | ||||||
|               env_null: !Env [bar-baz, null] |               env_null: !Env [bar-baz, null] | ||||||
|               json_parse: !ParseJSON '{"foo": "bar"}' |  | ||||||
|               policy_pk1: |               policy_pk1: | ||||||
|                   !Format [ |                   !Format [ | ||||||
|                       "%s-%s", |                       "%s-%s", | ||||||
|  | |||||||
| @ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable: | |||||||
|  |  | ||||||
|  |  | ||||||
| for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||||
|     if "local" in str(blueprint_file) or "testing" in str(blueprint_file): |     if "local" in str(blueprint_file): | ||||||
|         continue |         continue | ||||||
|     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) |     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ from collections.abc import Callable | |||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from authentik.blueprints.v1.importer import is_model_allowed | ||||||
| from authentik.lib.models import SerializerModel | from authentik.lib.models import SerializerModel | ||||||
| from authentik.providers.oauth2.models import RefreshToken | from authentik.providers.oauth2.models import RefreshToken | ||||||
|  |  | ||||||
| @ -21,13 +22,10 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: | |||||||
|             return |             return | ||||||
|         model_class = test_model() |         model_class = test_model() | ||||||
|         self.assertTrue(isinstance(model_class, SerializerModel)) |         self.assertTrue(isinstance(model_class, SerializerModel)) | ||||||
|         # Models that have subclasses don't have to have a serializer |  | ||||||
|         if len(test_model.__subclasses__()) > 0: |  | ||||||
|             return |  | ||||||
|         self.assertIsNotNone(model_class.serializer) |         self.assertIsNotNone(model_class.serializer) | ||||||
|         if model_class.serializer.Meta().model == RefreshToken: |         if model_class.serializer.Meta().model == RefreshToken: | ||||||
|             return |             return | ||||||
|         self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model)) |         self.assertEqual(model_class.serializer.Meta().model, test_model) | ||||||
|  |  | ||||||
|     return tester |     return tester | ||||||
|  |  | ||||||
| @ -36,6 +34,6 @@ for app in apps.get_app_configs(): | |||||||
|     if not app.label.startswith("authentik"): |     if not app.label.startswith("authentik"): | ||||||
|         continue |         continue | ||||||
|     for model in app.get_models(): |     for model in app.get_models(): | ||||||
|         if not issubclass(model, SerializerModel): |         if not is_model_allowed(model): | ||||||
|             continue |             continue | ||||||
|         setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model)) |         setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model)) | ||||||
|  | |||||||
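
The `TestPackaged` loop in the previous hunk and the `serializer_tester_factory` loop above share the same pattern: build the test function in a factory so it closes over one blueprint file or model, then attach it to the TestCase class with `setattr` at import time. A minimal, self-contained sketch of that pattern with made-up data (nothing authentik-specific):

```python
# Minimal sketch of the setattr-based test generation used above; the data and
# test body are made up, only the pattern (factory closure + setattr) matches.
import unittest
from collections.abc import Callable


class TestGenerated(unittest.TestCase):
    """Container that the generated test methods are attached to."""


def tester_factory(value: int) -> Callable:
    def tester(self: TestGenerated):
        # Each generated test closes over its own `value`.
        self.assertGreaterEqual(value, 0)

    return tester


for index, value in enumerate([0, 3, 7]):
    setattr(TestGenerated, f"test_value_{index}", tester_factory(value))


if __name__ == "__main__":
    unittest.main()
```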
| @ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase): | |||||||
|                     }, |                     }, | ||||||
|                     "nested_context": "context-nested-value", |                     "nested_context": "context-nested-value", | ||||||
|                     "env_null": None, |                     "env_null": None, | ||||||
|                     "json_parse": {"foo": "bar"}, |  | ||||||
|                     "at_index_sequence": "foo", |                     "at_index_sequence": "foo", | ||||||
|                     "at_index_sequence_default": "non existent", |                     "at_index_sequence_default": "non existent", | ||||||
|                     "at_index_mapping": 2, |                     "at_index_mapping": 2, | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from copy import copy | |||||||
| from dataclasses import asdict, dataclass, field, is_dataclass | from dataclasses import asdict, dataclass, field, is_dataclass | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from functools import reduce | from functools import reduce | ||||||
| from json import JSONDecodeError, loads |  | ||||||
| from operator import ixor | from operator import ixor | ||||||
| from os import getenv | from os import getenv | ||||||
| from typing import Any, Literal, Union | from typing import Any, Literal, Union | ||||||
| @ -292,22 +291,6 @@ class Context(YAMLTag): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|  |  | ||||||
| class ParseJSON(YAMLTag): |  | ||||||
|     """Parse JSON from context/env/etc value""" |  | ||||||
|  |  | ||||||
|     raw: str |  | ||||||
|  |  | ||||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: |  | ||||||
|         super().__init__() |  | ||||||
|         self.raw = node.value |  | ||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |  | ||||||
|         try: |  | ||||||
|             return loads(self.raw) |  | ||||||
|         except JSONDecodeError as exc: |  | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Format(YAMLTag): | class Format(YAMLTag): | ||||||
|     """Format a string""" |     """Format a string""" | ||||||
|  |  | ||||||
| @ -683,7 +666,6 @@ class BlueprintLoader(SafeLoader): | |||||||
|         self.add_constructor("!Value", Value) |         self.add_constructor("!Value", Value) | ||||||
|         self.add_constructor("!Index", Index) |         self.add_constructor("!Index", Index) | ||||||
|         self.add_constructor("!AtIndex", AtIndex) |         self.add_constructor("!AtIndex", AtIndex) | ||||||
|         self.add_constructor("!ParseJSON", ParseJSON) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EntryInvalidError(SentryIgnoredException): | class EntryInvalidError(SentryIgnoredException): | ||||||
|  | |||||||
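
The left column above removes the `!ParseJSON` tag: a `YAMLTag` subclass that keeps the scalar's raw value and runs `json.loads` over it when the entry is resolved, registered on the blueprint's `SafeLoader` subclass via `add_constructor`. Outside of the `YAMLTag`/`BlueprintEntry` machinery, the same PyYAML mechanism looks roughly like the standalone sketch below, which parses the JSON at load time instead of at resolve time.

```python
# Standalone sketch of a "!ParseJSON"-style tag on a plain yaml.SafeLoader
# subclass; the blueprint loader above does the same registration but defers
# json.loads to its resolve() phase instead of load time.
from json import JSONDecodeError, loads

import yaml


class SketchLoader(yaml.SafeLoader):
    """SafeLoader subclass with one extra constructor."""


def construct_parse_json(loader: SketchLoader, node: yaml.ScalarNode):
    raw = loader.construct_scalar(node)
    try:
        return loads(raw)
    except JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in !ParseJSON tag: {exc}") from exc


SketchLoader.add_constructor("!ParseJSON", construct_parse_json)

document = 'json_parse: !ParseJSON \'{"foo": "bar"}\''
print(yaml.load(document, Loader=SketchLoader))  # {'json_parse': {'foo': 'bar'}}
```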
| @ -43,7 +43,6 @@ from authentik.core.models import ( | |||||||
| ) | ) | ||||||
| from authentik.enterprise.license import LicenseKey | from authentik.enterprise.license import LicenseKey | ||||||
| from authentik.enterprise.models import LicenseUsage | from authentik.enterprise.models import LicenseUsage | ||||||
| from authentik.enterprise.providers.apple_psso.models import AppleNonce |  | ||||||
| from authentik.enterprise.providers.google_workspace.models import ( | from authentik.enterprise.providers.google_workspace.models import ( | ||||||
|     GoogleWorkspaceProviderGroup, |     GoogleWorkspaceProviderGroup, | ||||||
|     GoogleWorkspaceProviderUser, |     GoogleWorkspaceProviderUser, | ||||||
| @ -136,7 +135,6 @@ def excluded_models() -> list[type[Model]]: | |||||||
|         EndpointDeviceConnection, |         EndpointDeviceConnection, | ||||||
|         DeviceToken, |         DeviceToken, | ||||||
|         StreamEvent, |         StreamEvent, | ||||||
|         AppleNonce, |  | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| """Authenticator Devices API Views""" | """Authenticator Devices API Views""" | ||||||
|  |  | ||||||
| from drf_spectacular.utils import extend_schema | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from drf_spectacular.types import OpenApiTypes | ||||||
|  | from drf_spectacular.utils import OpenApiParameter, extend_schema | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.fields import ( | from rest_framework.fields import ( | ||||||
|     BooleanField, |     BooleanField, | ||||||
| @ -13,7 +15,6 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.viewsets import ViewSet | from rest_framework.viewsets import ViewSet | ||||||
|  |  | ||||||
| from authentik.core.api.users import ParamUserSerializer |  | ||||||
| from authentik.core.api.utils import MetaNameSerializer | from authentik.core.api.utils import MetaNameSerializer | ||||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | ||||||
| from authentik.stages.authenticator import device_classes, devices_for_user | from authentik.stages.authenticator import device_classes, devices_for_user | ||||||
| @ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice | |||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceSerializer(MetaNameSerializer): | class DeviceSerializer(MetaNameSerializer): | ||||||
|     """Serializer for authenticator devices""" |     """Serializer for Duo authenticator devices""" | ||||||
|  |  | ||||||
|     pk = CharField() |     pk = CharField() | ||||||
|     name = CharField() |     name = CharField() | ||||||
| @ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer): | |||||||
|     last_updated = DateTimeField(read_only=True) |     last_updated = DateTimeField(read_only=True) | ||||||
|     last_used = DateTimeField(read_only=True, allow_null=True) |     last_used = DateTimeField(read_only=True, allow_null=True) | ||||||
|     extra_description = SerializerMethodField() |     extra_description = SerializerMethodField() | ||||||
|     external_id = SerializerMethodField() |  | ||||||
|  |  | ||||||
|     def get_type(self, instance: Device) -> str: |     def get_type(self, instance: Device) -> str: | ||||||
|         """Get type of device""" |         """Get type of device""" | ||||||
|         return instance._meta.label |         return instance._meta.label | ||||||
|  |  | ||||||
|     def get_extra_description(self, instance: Device) -> str | None: |     def get_extra_description(self, instance: Device) -> str: | ||||||
|         """Get extra description""" |         """Get extra description""" | ||||||
|         if isinstance(instance, WebAuthnDevice): |         if isinstance(instance, WebAuthnDevice): | ||||||
|             return instance.device_type.description if instance.device_type else None |             return ( | ||||||
|  |                 instance.device_type.description | ||||||
|  |                 if instance.device_type | ||||||
|  |                 else _("Extra description not available") | ||||||
|  |             ) | ||||||
|         if isinstance(instance, EndpointDevice): |         if isinstance(instance, EndpointDevice): | ||||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") |             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||||
|         return None |         return "" | ||||||
|  |  | ||||||
|     def get_external_id(self, instance: Device) -> str | None: |  | ||||||
|         """Get external Device ID""" |  | ||||||
|         if isinstance(instance, WebAuthnDevice): |  | ||||||
|             return instance.device_type.aaguid if instance.device_type else None |  | ||||||
|         if isinstance(instance, EndpointDevice): |  | ||||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceViewSet(ViewSet): | class DeviceViewSet(ViewSet): | ||||||
| @ -61,6 +57,7 @@ class DeviceViewSet(ViewSet): | |||||||
|     serializer_class = DeviceSerializer |     serializer_class = DeviceSerializer | ||||||
|     permission_classes = [IsAuthenticated] |     permission_classes = [IsAuthenticated] | ||||||
|  |  | ||||||
|  |     @extend_schema(responses={200: DeviceSerializer(many=True)}) | ||||||
|     def list(self, request: Request) -> Response: |     def list(self, request: Request) -> Response: | ||||||
|         """Get all devices for current user""" |         """Get all devices for current user""" | ||||||
|         devices = devices_for_user(request.user) |         devices = devices_for_user(request.user) | ||||||
| @ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet): | |||||||
|             yield from device_set |             yield from device_set | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         parameters=[ParamUserSerializer], |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 name="user", | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 type=OpenApiTypes.INT, | ||||||
|  |             ) | ||||||
|  |         ], | ||||||
|         responses={200: DeviceSerializer(many=True)}, |         responses={200: DeviceSerializer(many=True)}, | ||||||
|     ) |     ) | ||||||
|     def list(self, request: Request) -> Response: |     def list(self, request: Request) -> Response: | ||||||
|         """Get all devices for current user""" |         """Get all devices for current user""" | ||||||
|         args = ParamUserSerializer(data=request.query_params) |         kwargs = {} | ||||||
|         args.is_valid(raise_exception=True) |         if "user" in request.query_params: | ||||||
|         return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data) |             kwargs = {"user": request.query_params["user"]} | ||||||
|  |         return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data) | ||||||
|  | |||||||
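
In the left column, `AdminDeviceViewSet.list` validates the `user` query parameter through `ParamUserSerializer` (imported from `authentik.core.api.users`) and feeds `validated_data` straight into `get_devices()`, while the right column reads `request.query_params` by hand. The standalone sketch below shows the serializer-based variant; it assumes Django and djangorestframework are installed, and uses a plain `IntegerField` in place of the real `PrimaryKeyRelatedField`, since no database or `User` model is configured here.

```python
# Standalone sketch (assumes Django + djangorestframework are installed); the
# IntegerField is a stand-in for ParamUserSerializer's PrimaryKeyRelatedField,
# which would need a configured User model and database.
import django
from django.conf import settings

if not settings.configured:
    settings.configure(USE_I18N=False)
    django.setup()

from rest_framework import serializers


class ParamUserSketch(serializers.Serializer):
    user = serializers.IntegerField(required=False)


def device_filter_kwargs(query_params: dict) -> dict:
    """Validate query parameters and return kwargs for a get_devices() call."""
    params = ParamUserSketch(data=query_params)
    params.is_valid(raise_exception=True)
    return params.validated_data


print(device_filter_kwargs({"user": "42"}))  # {'user': 42}
print(device_filter_kwargs({}))              # {}
```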
| @ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ParamUserSerializer(PassiveSerializer): |  | ||||||
|     """Partial serializer for query parameters to select a user""" |  | ||||||
|  |  | ||||||
|     user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserGroupSerializer(ModelSerializer): | class UserGroupSerializer(ModelSerializer): | ||||||
|     """Simplified Group Serializer for user's groups""" |     """Simplified Group Serializer for user's groups""" | ||||||
|  |  | ||||||
| @ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     queryset = User.objects.none() |     queryset = User.objects.none() | ||||||
|     ordering = ["username"] |     ordering = ["username"] | ||||||
|     serializer_class = UserSerializer |     serializer_class = UserSerializer | ||||||
|     filterset_class = UsersFilter |  | ||||||
|     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] |     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] | ||||||
|  |     filterset_class = UsersFilter | ||||||
|     def get_ql_fields(self): |  | ||||||
|         from djangoql.schema import BoolField, StrField |  | ||||||
|  |  | ||||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             StrField(User, "username"), |  | ||||||
|             StrField(User, "name"), |  | ||||||
|             StrField(User, "email"), |  | ||||||
|             StrField(User, "path"), |  | ||||||
|             BoolField(User, "is_active", nullable=True), |  | ||||||
|             ChoiceSearchField(User, "type"), |  | ||||||
|             JSONSearchField(User, "attributes", suggest_nested=False), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     def get_queryset(self): |     def get_queryset(self): | ||||||
|         base_qs = User.objects.all().exclude_anonymous() |         base_qs = User.objects.all().exclude_anonymous() | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.db import models |  | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from drf_spectacular.extensions import OpenApiSerializerFieldExtension | from drf_spectacular.extensions import OpenApiSerializerFieldExtension | ||||||
| from drf_spectacular.plumbing import build_basic_type | from drf_spectacular.plumbing import build_basic_type | ||||||
| @ -31,27 +30,7 @@ def is_dict(value: Any): | |||||||
|     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") |     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONDictField(JSONField): |  | ||||||
|     """JSON Field which only allows dictionaries""" |  | ||||||
|  |  | ||||||
|     default_validators = [is_dict] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONExtension(OpenApiSerializerFieldExtension): |  | ||||||
|     """Generate API Schema for JSON fields as""" |  | ||||||
|  |  | ||||||
|     target_class = "authentik.core.api.utils.JSONDictField" |  | ||||||
|  |  | ||||||
|     def map_serializer_field(self, auto_schema, direction): |  | ||||||
|         return build_basic_type(OpenApiTypes.OBJECT) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ModelSerializer(BaseModelSerializer): | class ModelSerializer(BaseModelSerializer): | ||||||
|  |  | ||||||
|     # By default, JSON fields we have are used to store dictionaries |  | ||||||
|     serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy() |  | ||||||
|     serializer_field_mapping[models.JSONField] = JSONDictField |  | ||||||
|  |  | ||||||
|     def create(self, validated_data): |     def create(self, validated_data): | ||||||
|         instance = super().create(validated_data) |         instance = super().create(validated_data) | ||||||
|  |  | ||||||
| @ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer): | |||||||
|         return instance |         return instance | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class JSONDictField(JSONField): | ||||||
|  |     """JSON Field which only allows dictionaries""" | ||||||
|  |  | ||||||
|  |     default_validators = [is_dict] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class JSONExtension(OpenApiSerializerFieldExtension): | ||||||
|  |     """Generate API Schema for JSON fields as""" | ||||||
|  |  | ||||||
|  |     target_class = "authentik.core.api.utils.JSONDictField" | ||||||
|  |  | ||||||
|  |     def map_serializer_field(self, auto_schema, direction): | ||||||
|  |         return build_basic_type(OpenApiTypes.OBJECT) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PassiveSerializer(Serializer): | class PassiveSerializer(Serializer): | ||||||
|     """Base serializer class which doesn't implement create/update methods""" |     """Base serializer class which doesn't implement create/update methods""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ class Command(TenantCommand): | |||||||
|         parser.add_argument("usernames", nargs="*", type=str) |         parser.add_argument("usernames", nargs="*", type=str) | ||||||
|  |  | ||||||
|     def handle_per_tenant(self, **options): |     def handle_per_tenant(self, **options): | ||||||
|  |         print(options) | ||||||
|         new_type = UserTypes(options["type"]) |         new_type = UserTypes(options["type"]) | ||||||
|         qs = ( |         qs = ( | ||||||
|             User.objects.exclude_anonymous() |             User.objects.exclude_anonymous() | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from django.http import HttpRequest | |||||||
| from django.utils.functional import SimpleLazyObject, cached_property | from django.utils.functional import SimpleLazyObject, cached_property | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from django_cte import CTE, with_cte | from django_cte import CTEQuerySet, With | ||||||
| from guardian.conf import settings | from guardian.conf import settings | ||||||
| from guardian.mixins import GuardianUserMixin | from guardian.mixins import GuardianUserMixin | ||||||
| from model_utils.managers import InheritanceManager | from model_utils.managers import InheritanceManager | ||||||
| @ -136,7 +136,7 @@ class AttributesMixin(models.Model): | |||||||
|         return instance, False |         return instance, False | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupQuerySet(QuerySet): | class GroupQuerySet(CTEQuerySet): | ||||||
|     def with_children_recursive(self): |     def with_children_recursive(self): | ||||||
|         """Recursively get all groups that have the current queryset as parents |         """Recursively get all groups that have the current queryset as parents | ||||||
|         or are indirectly related.""" |         or are indirectly related.""" | ||||||
| @ -165,9 +165,9 @@ class GroupQuerySet(QuerySet): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         # Build the recursive query, see above |         # Build the recursive query, see above | ||||||
|         cte = CTE.recursive(make_cte) |         cte = With.recursive(make_cte) | ||||||
|         # Return the result, as a usable queryset for Group. |         # Return the result, as a usable queryset for Group. | ||||||
|         return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid)) |         return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Group(SerializerModel, AttributesMixin): | class Group(SerializerModel, AttributesMixin): | ||||||
| @ -1082,12 +1082,6 @@ class AuthenticatedSession(SerializerModel): | |||||||
|  |  | ||||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) |     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def serializer(self) -> type[Serializer]: |  | ||||||
|         from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer |  | ||||||
|  |  | ||||||
|         return AuthenticatedSessionSerializer |  | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Authenticated Session") |         verbose_name = _("Authenticated Session") | ||||||
|         verbose_name_plural = _("Authenticated Sessions") |         verbose_name_plural = _("Authenticated Sessions") | ||||||
|  | |||||||
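
The `GroupQuerySet` hunk switches between two django-cte APIs for the same recursive query: `CTE.recursive(...)` plus the module-level `with_cte(cte, select=...)` helper in the left column, and the older `With.recursive(...)` plus the queryset's `.with_cte(cte)` method in the right. Since the body of `make_cte` is elided by the diff, the sketch below only shows the general shape of such a callback, using a hypothetical self-referencing `Node` model and assuming the django-cte 2.x API from the left column; it is not authentik's actual `Group` query and is not runnable outside a configured Django project.

```python
# Shape sketch only: Node is a hypothetical self-referencing model, and this
# assumes django-cte 2.x (CTE.recursive / with_cte) as in the left column.
from django.db import models
from django_cte import CTE, with_cte


class Node(models.Model):
    parent = models.ForeignKey("self", null=True, on_delete=models.CASCADE)

    class Meta:
        app_label = "sketch"


def descendants_of(roots):
    """Return roots plus all nodes reachable through the parent relation."""

    def make_cte(cte):
        # Base case: the starting rows; recursive case: rows whose parent is
        # already in the CTE. UNION ALL keeps the recursion going.
        return roots.values("id").union(
            cte.join(Node, parent_id=cte.col.id).values("id"),
            all=True,
        )

    cte = CTE.recursive(make_cte)
    # Join the CTE back to Node so the result is a normal Node queryset.
    return with_cte(cte, select=cte.join(Node, id=cte.col.id))
```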
| @ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             [ |             [ | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter1",  # nosec |                     old_password="hunter1",  # nosec B106 | ||||||
|                     created_at=_now - timedelta(days=3), |                     created_at=_now - timedelta(days=3), | ||||||
|                 ), |                 ), | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter2",  # nosec |                     old_password="hunter2",  # nosec B106 | ||||||
|                     created_at=_now - timedelta(days=2), |                     created_at=_now - timedelta(days=2), | ||||||
|                 ), |                 ), | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter3",  # nosec |                     old_password="hunter3",  # nosec B106 | ||||||
|                     created_at=_now, |                     created_at=_now, | ||||||
|                 ), |                 ), | ||||||
|             ] |             ] | ||||||
|  | |||||||
| @ -1,32 +0,0 @@ | |||||||
| """Apple Platform SSO Provider API Views""" |  | ||||||
|  |  | ||||||
| from rest_framework.viewsets import ModelViewSet |  | ||||||
|  |  | ||||||
| from authentik.core.api.providers import ProviderSerializer |  | ||||||
| from authentik.core.api.used_by import UsedByMixin |  | ||||||
| from authentik.enterprise.api import EnterpriseRequiredMixin |  | ||||||
| from authentik.enterprise.providers.apple_psso.models import ApplePlatformSSOProvider |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ApplePlatformSSOProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer): |  | ||||||
|     """ApplePlatformSSOProvider Serializer""" |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         model = ApplePlatformSSOProvider |  | ||||||
|         fields = [ |  | ||||||
|             "pk", |  | ||||||
|             "name", |  | ||||||
|         ] |  | ||||||
|         extra_kwargs = {} |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ApplePlatformSSOProviderViewSet(UsedByMixin, ModelViewSet): |  | ||||||
|     """ApplePlatformSSOProvider Viewset""" |  | ||||||
|  |  | ||||||
|     queryset = ApplePlatformSSOProvider.objects.all() |  | ||||||
|     serializer_class = ApplePlatformSSOProviderSerializer |  | ||||||
|     filterset_fields = [ |  | ||||||
|         "name", |  | ||||||
|     ] |  | ||||||
|     search_fields = ["name"] |  | ||||||
|     ordering = ["name"] |  | ||||||
| @ -1,13 +0,0 @@ | |||||||
| from authentik.enterprise.apps import EnterpriseConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikEnterpriseProviderApplePSSOConfig(EnterpriseConfig): |  | ||||||
|  |  | ||||||
|     name = "authentik.enterprise.providers.apple_psso" |  | ||||||
|     label = "authentik_providers_apple_psso" |  | ||||||
|     verbose_name = "authentik Enterprise.Providers.Apple Platform SSO" |  | ||||||
|     default = True |  | ||||||
|     mountpoints = { |  | ||||||
|         "authentik.enterprise.providers.apple_psso.urls": "endpoint/apple/sso/", |  | ||||||
|         "authentik.enterprise.providers.apple_psso.urls_root": "", |  | ||||||
|     } |  | ||||||
| @ -1,118 +0,0 @@ | |||||||
| from base64 import urlsafe_b64encode |  | ||||||
| from json import dumps |  | ||||||
| from secrets import token_bytes |  | ||||||
|  |  | ||||||
| from cryptography.hazmat.backends import default_backend |  | ||||||
| from cryptography.hazmat.primitives import hashes, serialization |  | ||||||
| from cryptography.hazmat.primitives.asymmetric import ec |  | ||||||
| from cryptography.hazmat.primitives.ciphers.aead import AESGCM |  | ||||||
| from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash |  | ||||||
| from django.http import HttpResponse |  | ||||||
| from jwcrypto.common import base64url_decode, base64url_encode |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.apple_psso.models import AppleDevice |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def length_prefixed(data: bytes) -> bytes: |  | ||||||
|     length = len(data) |  | ||||||
|     return length.to_bytes(4, "big") + data |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_apu(public_key: ec.EllipticCurvePublicKey): |  | ||||||
|     # X9.63 representation: 0x04 || X || Y |  | ||||||
|     public_numbers = public_key.public_numbers() |  | ||||||
|  |  | ||||||
|     x_bytes = public_numbers.x.to_bytes(32, "big") |  | ||||||
|     y_bytes = public_numbers.y.to_bytes(32, "big") |  | ||||||
|  |  | ||||||
|     x963 = bytes([0x04]) + x_bytes + y_bytes |  | ||||||
|  |  | ||||||
|     result = length_prefixed(b"APPLE") + length_prefixed(x963) |  | ||||||
|  |  | ||||||
|     return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def encrypt_token_with_a256_gcm(body: dict, device_encryption_key: str, apv: bytes) -> str: |  | ||||||
|     ephemeral_key = ec.generate_private_key(curve=ec.SECP256R1()) |  | ||||||
|     device_public_key = serialization.load_pem_public_key( |  | ||||||
|         device_encryption_key.encode(), backend=default_backend() |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     shared_secret_z = ephemeral_key.exchange(ec.ECDH(), device_public_key) |  | ||||||
|  |  | ||||||
|     apu = build_apu(ephemeral_key.public_key()) |  | ||||||
|  |  | ||||||
|     jwe_header = { |  | ||||||
|         "enc": "A256GCM", |  | ||||||
|         "kid": "ephemeralKey", |  | ||||||
|         "epk": { |  | ||||||
|             "x": base64url_encode( |  | ||||||
|                 ephemeral_key.public_key().public_numbers().x.to_bytes(32, "big") |  | ||||||
|             ), |  | ||||||
|             "y": base64url_encode( |  | ||||||
|                 ephemeral_key.public_key().public_numbers().y.to_bytes(32, "big") |  | ||||||
|             ), |  | ||||||
|             "kty": "EC", |  | ||||||
|             "crv": "P-256", |  | ||||||
|         }, |  | ||||||
|         "typ": "platformsso-login-response+jwt", |  | ||||||
|         "alg": "ECDH-ES", |  | ||||||
|         "apu": base64url_encode(apu), |  | ||||||
|         "apv": base64url_encode(apv), |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     party_u_info = length_prefixed(apu) |  | ||||||
|     party_v_info = length_prefixed(apv) |  | ||||||
|     supp_pub_info = (256).to_bytes(4, "big") |  | ||||||
|  |  | ||||||
|     other_info = length_prefixed(b"A256GCM") + party_u_info + party_v_info + supp_pub_info |  | ||||||
|  |  | ||||||
|     ckdf = ConcatKDFHash( |  | ||||||
|         algorithm=hashes.SHA256(), |  | ||||||
|         length=32, |  | ||||||
|         otherinfo=other_info, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     derived_key = ckdf.derive(shared_secret_z) |  | ||||||
|  |  | ||||||
|     nonce = token_bytes(12) |  | ||||||
|  |  | ||||||
|     header_json = dumps(jwe_header, separators=(",", ":")).encode() |  | ||||||
|     aad = urlsafe_b64encode(header_json).rstrip(b"=") |  | ||||||
|  |  | ||||||
|     aesgcm = AESGCM(derived_key) |  | ||||||
|     ciphertext = aesgcm.encrypt(nonce, dumps(body).encode(), aad) |  | ||||||
|  |  | ||||||
|     ciphertext_body = ciphertext[:-16] |  | ||||||
|     tag = ciphertext[-16:] |  | ||||||
|  |  | ||||||
|     # base64url encoding |  | ||||||
|     protected_b64 = urlsafe_b64encode(header_json).rstrip(b"=") |  | ||||||
|     iv_b64 = urlsafe_b64encode(nonce).rstrip(b"=") |  | ||||||
|     ciphertext_b64 = urlsafe_b64encode(ciphertext_body).rstrip(b"=") |  | ||||||
|     tag_b64 = urlsafe_b64encode(tag).rstrip(b"=") |  | ||||||
|  |  | ||||||
|     jwe_compact = b".".join( |  | ||||||
|         [ |  | ||||||
|             protected_b64, |  | ||||||
|             b"", |  | ||||||
|             iv_b64, |  | ||||||
|             ciphertext_b64, |  | ||||||
|             tag_b64, |  | ||||||
|         ] |  | ||||||
|     ) |  | ||||||
|     return jwe_compact.decode() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class JWEResponse(HttpResponse): |  | ||||||
|  |  | ||||||
|     def __init__( |  | ||||||
|         self, |  | ||||||
|         data: dict, |  | ||||||
|         device: AppleDevice, |  | ||||||
|         apv: str, |  | ||||||
|     ): |  | ||||||
|         super().__init__( |  | ||||||
|             content=encrypt_token_with_a256_gcm(data, device.encryption_key, base64url_decode(apv)), |  | ||||||
|             content_type="application/platformsso-login-response+jwt", |  | ||||||
|         ) |  | ||||||
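
The removed `encrypt_token_with_a256_gcm` builds an ECDH-ES + ConcatKDF(SHA-256) + A256GCM JWE by hand: ephemeral P-256 key, X9.63-based `apu`, ConcatKDF OtherInfo assembled from length-prefixed fields, and the base64url-encoded protected header used as AAD. To make the round trip easier to follow, here is a hedged sketch of the receiving side that simply undoes those steps; it is written against the removed code above, as an illustration only, not as Apple's or authentik's actual device-side implementation.

```python
# Sketch of the decrypt counterpart to the removed encrypt_token_with_a256_gcm:
# assumes the JWE was produced by that function and that `device_private_key`
# is the EC P-256 key matching the device's encryption_key.
from json import loads

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from jwcrypto.common import base64url_decode


def length_prefixed(data: bytes) -> bytes:
    return len(data).to_bytes(4, "big") + data


def decrypt_login_response(
    jwe_compact: str, device_private_key: ec.EllipticCurvePrivateKey
) -> dict:
    protected_b64, _, iv_b64, ciphertext_b64, tag_b64 = jwe_compact.split(".")
    header = loads(base64url_decode(protected_b64))

    # Rebuild the sender's ephemeral public key from the epk header.
    epk = header["epk"]
    x = int.from_bytes(base64url_decode(epk["x"]), "big")
    y = int.from_bytes(base64url_decode(epk["y"]), "big")
    ephemeral_public = ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key()

    shared_secret_z = device_private_key.exchange(ec.ECDH(), ephemeral_public)

    # Same ConcatKDF OtherInfo as the encrypt side: AlgorithmID, apu, apv, keydatalen.
    other_info = (
        length_prefixed(b"A256GCM")
        + length_prefixed(base64url_decode(header["apu"]))
        + length_prefixed(base64url_decode(header["apv"]))
        + (256).to_bytes(4, "big")
    )
    derived_key = ConcatKDFHash(
        algorithm=hashes.SHA256(), length=32, otherinfo=other_info
    ).derive(shared_secret_z)

    # The encrypt side used the base64url-encoded protected header as AAD.
    aad = protected_b64.encode()
    ciphertext = base64url_decode(ciphertext_b64) + base64url_decode(tag_b64)
    plaintext = AESGCM(derived_key).decrypt(base64url_decode(iv_b64), ciphertext, aad)
    return loads(plaintext)
```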
| @ -1,36 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-28 00:12 |  | ||||||
|  |  | ||||||
| import django.db.models.deletion |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     initial = True |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_providers_oauth2", "0028_migrate_session"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="ApplePlatformSSOProvider", |  | ||||||
|             fields=[ |  | ||||||
|                 ( |  | ||||||
|                     "oauth2provider_ptr", |  | ||||||
|                     models.OneToOneField( |  | ||||||
|                         auto_created=True, |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, |  | ||||||
|                         parent_link=True, |  | ||||||
|                         primary_key=True, |  | ||||||
|                         serialize=False, |  | ||||||
|                         to="authentik_providers_oauth2.oauth2provider", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|             ], |  | ||||||
|             options={ |  | ||||||
|                 "abstract": False, |  | ||||||
|             }, |  | ||||||
|             bases=("authentik_providers_oauth2.oauth2provider",), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,94 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-28 15:50 |  | ||||||
|  |  | ||||||
| import django.db.models.deletion |  | ||||||
| import uuid |  | ||||||
| from django.conf import settings |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_providers_apple_psso", "0001_initial"), |  | ||||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="AppleDevice", |  | ||||||
|             fields=[ |  | ||||||
|                 ( |  | ||||||
|                     "endpoint_uuid", |  | ||||||
|                     models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), |  | ||||||
|                 ), |  | ||||||
|                 ("signing_key", models.TextField()), |  | ||||||
|                 ("encryption_key", models.TextField()), |  | ||||||
|                 ("key_exchange_key", models.TextField()), |  | ||||||
|                 ("sign_key_id", models.TextField()), |  | ||||||
|                 ("enc_key_id", models.TextField()), |  | ||||||
|                 ("creation_time", models.DateTimeField(auto_now_add=True)), |  | ||||||
|                 ( |  | ||||||
|                     "provider", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, |  | ||||||
|                         to="authentik_providers_apple_psso.appleplatformssoprovider", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|             ], |  | ||||||
|         ), |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="AppleDeviceUser", |  | ||||||
|             fields=[ |  | ||||||
|                 ("uuid", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), |  | ||||||
|                 ("signing_key", models.TextField()), |  | ||||||
|                 ("encryption_key", models.TextField()), |  | ||||||
|                 ("sign_key_id", models.TextField()), |  | ||||||
|                 ("enc_key_id", models.TextField()), |  | ||||||
|                 ( |  | ||||||
|                     "device", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, |  | ||||||
|                         to="authentik_providers_apple_psso.appledevice", |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ( |  | ||||||
|                     "user", |  | ||||||
|                     models.ForeignKey( |  | ||||||
|                         on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|             ], |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="appledevice", |  | ||||||
|             name="users", |  | ||||||
|             field=models.ManyToManyField( |  | ||||||
|                 through="authentik_providers_apple_psso.AppleDeviceUser", |  | ||||||
|                 to=settings.AUTH_USER_MODEL, |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|         migrations.CreateModel( |  | ||||||
|             name="AppleNonce", |  | ||||||
|             fields=[ |  | ||||||
|                 ( |  | ||||||
|                     "id", |  | ||||||
|                     models.AutoField( |  | ||||||
|                         auto_created=True, primary_key=True, serialize=False, verbose_name="ID" |  | ||||||
|                     ), |  | ||||||
|                 ), |  | ||||||
|                 ("expires", models.DateTimeField(default=None, null=True)), |  | ||||||
|                 ("expiring", models.BooleanField(default=True)), |  | ||||||
|                 ("nonce", models.TextField()), |  | ||||||
|             ], |  | ||||||
|             options={ |  | ||||||
|                 "abstract": False, |  | ||||||
|                 "indexes": [ |  | ||||||
|                     models.Index(fields=["expires"], name="authentik_p_expires_47d534_idx"), |  | ||||||
|                     models.Index(fields=["expiring"], name="authentik_p_expirin_87253e_idx"), |  | ||||||
|                     models.Index( |  | ||||||
|                         fields=["expiring", "expires"], name="authentik_p_expirin_20a7c9_idx" |  | ||||||
|                     ), |  | ||||||
|                 ], |  | ||||||
|             }, |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,34 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-28 22:18 |  | ||||||
|  |  | ||||||
| from django.db import migrations |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ( |  | ||||||
|             "authentik_providers_apple_psso", |  | ||||||
|             "0002_appledevice_appledeviceuser_appledevice_users_and_more", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.RenameField( |  | ||||||
|             model_name="appledeviceuser", |  | ||||||
|             old_name="sign_key_id", |  | ||||||
|             new_name="enclave_key_id", |  | ||||||
|         ), |  | ||||||
|         migrations.RenameField( |  | ||||||
|             model_name="appledeviceuser", |  | ||||||
|             old_name="signing_key", |  | ||||||
|             new_name="secure_enclave_key", |  | ||||||
|         ), |  | ||||||
|         migrations.RemoveField( |  | ||||||
|             model_name="appledeviceuser", |  | ||||||
|             name="enc_key_id", |  | ||||||
|         ), |  | ||||||
|         migrations.RemoveField( |  | ||||||
|             model_name="appledeviceuser", |  | ||||||
|             name="encryption_key", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,85 +0,0 @@ | |||||||
| from uuid import uuid4 |  | ||||||
|  |  | ||||||
| from django.db import models |  | ||||||
| from django.utils.translation import gettext_lazy as _ |  | ||||||
| from rest_framework.serializers import Serializer |  | ||||||
|  |  | ||||||
| from authentik.core.models import ExpiringModel, User |  | ||||||
| from authentik.crypto.models import CertificateKeyPair |  | ||||||
| from authentik.providers.oauth2.models import ( |  | ||||||
|     ClientTypes, |  | ||||||
|     IssuerMode, |  | ||||||
|     OAuth2Provider, |  | ||||||
|     RedirectURI, |  | ||||||
|     RedirectURIMatchingMode, |  | ||||||
|     ScopeMapping, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ApplePlatformSSOProvider(OAuth2Provider): |  | ||||||
|     """Integrate with Apple Platform SSO""" |  | ||||||
|  |  | ||||||
|     def set_oauth_defaults(self): |  | ||||||
|         """Ensure all OAuth2-related settings are correct""" |  | ||||||
|         self.issuer_mode = IssuerMode.PER_PROVIDER |  | ||||||
|         self.client_type = ClientTypes.PUBLIC |  | ||||||
|         self.signing_key = CertificateKeyPair.objects.get(name="authentik Self-signed Certificate") |  | ||||||
|         self.include_claims_in_id_token = True |  | ||||||
|         scopes = ScopeMapping.objects.filter( |  | ||||||
|             managed__in=[ |  | ||||||
|                 "goauthentik.io/providers/oauth2/scope-openid", |  | ||||||
|                 "goauthentik.io/providers/oauth2/scope-profile", |  | ||||||
|                 "goauthentik.io/providers/oauth2/scope-email", |  | ||||||
|                 "goauthentik.io/providers/oauth2/scope-offline_access", |  | ||||||
|                 "goauthentik.io/providers/oauth2/scope-authentik_api", |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
|         self.property_mappings.add(*list(scopes)) |  | ||||||
|         self.redirect_uris = [ |  | ||||||
|             RedirectURI(RedirectURIMatchingMode.STRICT, "io.goauthentik.endpoint:/oauth2redirect"), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def component(self) -> str: |  | ||||||
|         return "ak-provider-apple-psso-form" |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def serializer(self) -> type[Serializer]: |  | ||||||
|         from authentik.enterprise.providers.apple_psso.api.providers import ( |  | ||||||
|             ApplePlatformSSOProviderSerializer, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         return ApplePlatformSSOProviderSerializer |  | ||||||
|  |  | ||||||
|     class Meta: |  | ||||||
|         verbose_name = _("Apple Platform SSO Provider") |  | ||||||
|         verbose_name_plural = _("Apple Platform SSO Providers") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AppleDevice(models.Model): |  | ||||||
|  |  | ||||||
|     endpoint_uuid = models.UUIDField(default=uuid4, primary_key=True) |  | ||||||
|  |  | ||||||
|     signing_key = models.TextField() |  | ||||||
|     encryption_key = models.TextField() |  | ||||||
|     key_exchange_key = models.TextField() |  | ||||||
|     sign_key_id = models.TextField() |  | ||||||
|     enc_key_id = models.TextField() |  | ||||||
|     creation_time = models.DateTimeField(auto_now_add=True) |  | ||||||
|     provider = models.ForeignKey(ApplePlatformSSOProvider, on_delete=models.CASCADE) |  | ||||||
|     users = models.ManyToManyField(User, through="AppleDeviceUser") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AppleDeviceUser(models.Model): |  | ||||||
|  |  | ||||||
|     uuid = models.UUIDField(default=uuid4, primary_key=True) |  | ||||||
|  |  | ||||||
|     device = models.ForeignKey(AppleDevice, on_delete=models.CASCADE) |  | ||||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) |  | ||||||
|  |  | ||||||
|     secure_enclave_key = models.TextField() |  | ||||||
|     enclave_key_id = models.TextField() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AppleNonce(ExpiringModel): |  | ||||||
|     nonce = models.TextField() |  | ||||||
| @ -1,15 +0,0 @@ | |||||||
| from django.urls import path |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.apple_psso.views.nonce import NonceView |  | ||||||
| from authentik.enterprise.providers.apple_psso.views.register import ( |  | ||||||
|     RegisterDeviceView, |  | ||||||
|     RegisterUserView, |  | ||||||
| ) |  | ||||||
| from authentik.enterprise.providers.apple_psso.views.token import TokenView |  | ||||||
|  |  | ||||||
| urlpatterns = [ |  | ||||||
|     path("token/", TokenView.as_view(), name="token"), |  | ||||||
|     path("nonce/", NonceView.as_view(), name="nonce"), |  | ||||||
|     path("register/device/", RegisterDeviceView.as_view(), name="register-device"), |  | ||||||
|     path("register/user/", RegisterUserView.as_view(), name="register-user"), |  | ||||||
| ] |  | ||||||
| @ -1,7 +0,0 @@ | |||||||
| from django.urls import path |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.apple_psso.views.site_association import AppleAppSiteAssociation |  | ||||||
|  |  | ||||||
| urlpatterns = [ |  | ||||||
|     path(".well-known/apple-app-site-association", AppleAppSiteAssociation.as_view(), name="asa"), |  | ||||||
| ] |  | ||||||
| @ -1,25 +0,0 @@ | |||||||
| from base64 import b64encode |  | ||||||
| from datetime import timedelta |  | ||||||
| from secrets import token_bytes |  | ||||||
|  |  | ||||||
| from django.http import HttpRequest, JsonResponse |  | ||||||
| from django.utils.decorators import method_decorator |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from django.views import View |  | ||||||
| from django.views.decorators.csrf import csrf_exempt |  | ||||||
|  |  | ||||||
| from authentik.enterprise.providers.apple_psso.models import AppleNonce |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @method_decorator(csrf_exempt, name="dispatch") |  | ||||||
| class NonceView(View): |  | ||||||
|  |  | ||||||
|     def post(self, request: HttpRequest, *args, **kwargs): |  | ||||||
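|         # Issue a single-use nonce valid for five minutes; TokenView deletes it once redeemed |  | ||||||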
|         nonce = AppleNonce.objects.create( |  | ||||||
|             nonce=b64encode(token_bytes(32)).decode(), expires=now() + timedelta(minutes=5) |  | ||||||
|         ) |  | ||||||
|         return JsonResponse( |  | ||||||
|             { |  | ||||||
|                 "Nonce": nonce.nonce, |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
| @ -1,92 +0,0 @@ | |||||||
| from django.shortcuts import get_object_or_404 |  | ||||||
| from rest_framework.authentication import BaseAuthentication |  | ||||||
| from rest_framework.fields import CharField |  | ||||||
| from rest_framework.request import Request |  | ||||||
| from rest_framework.response import Response |  | ||||||
| from rest_framework.views import APIView |  | ||||||
|  |  | ||||||
| from authentik.api.authentication import TokenAuthentication |  | ||||||
| from authentik.core.api.utils import PassiveSerializer |  | ||||||
| from authentik.core.models import User |  | ||||||
| from authentik.enterprise.providers.apple_psso.models import ( |  | ||||||
|     AppleDevice, |  | ||||||
|     AppleDeviceUser, |  | ||||||
|     ApplePlatformSSOProvider, |  | ||||||
| ) |  | ||||||
| from authentik.lib.generators import generate_key |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceRegisterAuth(BaseAuthentication): |  | ||||||
|     def authenticate(self, request): |  | ||||||
|         # Temporary placeholder: accept any request by returning a transient, unsaved user |  | ||||||
|         return (User(), None) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class RegisterDeviceView(APIView): |  | ||||||
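|     """Register or update an Apple device by its endpoint UUID, storing its signing and encryption keys""" |  | ||||||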
|  |  | ||||||
|     class DeviceRegistration(PassiveSerializer): |  | ||||||
|  |  | ||||||
|         device_uuid = CharField() |  | ||||||
|         client_id = CharField() |  | ||||||
|         device_signing_key = CharField() |  | ||||||
|         device_encryption_key = CharField() |  | ||||||
|         sign_key_id = CharField() |  | ||||||
|         enc_key_id = CharField() |  | ||||||
|  |  | ||||||
|     permission_classes = [] |  | ||||||
|     pagination_class = None |  | ||||||
|     filter_backends = [] |  | ||||||
|     serializer_class = DeviceRegistration |  | ||||||
|     authentication_classes = [DeviceRegisterAuth, TokenAuthentication] |  | ||||||
|  |  | ||||||
|     def post(self, request: Request) -> Response: |  | ||||||
|         data = self.DeviceRegistration(data=request.data) |  | ||||||
|         data.is_valid(raise_exception=True) |  | ||||||
|         provider = get_object_or_404( |  | ||||||
|             ApplePlatformSSOProvider, client_id=data.validated_data["client_id"] |  | ||||||
|         ) |  | ||||||
|         AppleDevice.objects.update_or_create( |  | ||||||
|             endpoint_uuid=data.validated_data["device_uuid"], |  | ||||||
|             defaults={ |  | ||||||
|                 "signing_key": data.validated_data["device_signing_key"], |  | ||||||
|                 "encryption_key": data.validated_data["device_encryption_key"], |  | ||||||
|                 "sign_key_id": data.validated_data["sign_key_id"], |  | ||||||
|                 "enc_key_id": data.validated_data["enc_key_id"], |  | ||||||
|                 "key_exchange_key": generate_key(), |  | ||||||
|                 "provider": provider, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         return Response() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class RegisterUserView(APIView): |  | ||||||
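|     """Associate the requesting user with a registered device and store their Secure Enclave key""" |  | ||||||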
|  |  | ||||||
|     class UserRegistration(PassiveSerializer): |  | ||||||
|  |  | ||||||
|         device_uuid = CharField() |  | ||||||
|         user_secure_enclave_key = CharField() |  | ||||||
|         enclave_key_id = CharField() |  | ||||||
|  |  | ||||||
|     permission_classes = [] |  | ||||||
|     pagination_class = None |  | ||||||
|     filter_backends = [] |  | ||||||
|     serializer_class = UserRegistration |  | ||||||
|     authentication_classes = [TokenAuthentication] |  | ||||||
|  |  | ||||||
|     def post(self, request: Request) -> Response: |  | ||||||
|         data = self.UserRegistration(data=request.data) |  | ||||||
|         data.is_valid(raise_exception=True) |  | ||||||
|         device = get_object_or_404(AppleDevice, endpoint_uuid=data.validated_data["device_uuid"]) |  | ||||||
|         AppleDeviceUser.objects.update_or_create( |  | ||||||
|             device=device, |  | ||||||
|             user=request.user, |  | ||||||
|             defaults={ |  | ||||||
|                 "secure_enclave_key": data.validated_data["user_secure_enclave_key"], |  | ||||||
|                 "enclave_key_id": data.validated_data["enclave_key_id"], |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         return Response( |  | ||||||
|             { |  | ||||||
|                 "username": request.user.username, |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
| @ -1,16 +0,0 @@ | |||||||
| from django.http import HttpRequest, HttpResponse, JsonResponse |  | ||||||
| from django.views import View |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AppleAppSiteAssociation(View): |  | ||||||
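|     """Serve the .well-known/apple-app-site-association file used for Apple Platform SSO""" |  | ||||||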
|     def get(self, request: HttpRequest) -> HttpResponse: |  | ||||||
|         return JsonResponse( |  | ||||||
|             { |  | ||||||
|                 "authsrv": { |  | ||||||
|                     "apps": [ |  | ||||||
|                         "232G855Y8N.io.goauthentik.endpoint", |  | ||||||
|                         "232G855Y8N.io.goauthentik.endpoint.psso", |  | ||||||
|                     ] |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
| @ -1,140 +0,0 @@ | |||||||
| from datetime import timedelta |  | ||||||
|  |  | ||||||
| from django.http import Http404, HttpRequest, HttpResponse |  | ||||||
| from django.utils.decorators import method_decorator |  | ||||||
| from django.utils.timezone import now |  | ||||||
| from django.views import View |  | ||||||
| from django.views.decorators.csrf import csrf_exempt |  | ||||||
| from jwt import PyJWT, decode |  | ||||||
| from rest_framework.exceptions import ValidationError |  | ||||||
| from structlog.stdlib import get_logger |  | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, Session, User |  | ||||||
| from authentik.core.sessions import SessionStore |  | ||||||
| from authentik.enterprise.providers.apple_psso.http import JWEResponse |  | ||||||
| from authentik.enterprise.providers.apple_psso.models import ( |  | ||||||
|     AppleDevice, |  | ||||||
|     AppleDeviceUser, |  | ||||||
|     AppleNonce, |  | ||||||
|     ApplePlatformSSOProvider, |  | ||||||
| ) |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.events.signals import SESSION_LOGIN_EVENT |  | ||||||
| from authentik.providers.oauth2.constants import TOKEN_TYPE |  | ||||||
| from authentik.providers.oauth2.id_token import IDToken |  | ||||||
| from authentik.providers.oauth2.models import RefreshToken |  | ||||||
| from authentik.root.middleware import SessionMiddleware |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @method_decorator(csrf_exempt, name="dispatch") |  | ||||||
| class TokenView(View): |  | ||||||
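|     """Apple Platform SSO token endpoint: validates device and user assertions and issues tokens""" |  | ||||||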
|  |  | ||||||
|     device: AppleDevice |  | ||||||
|     provider: ApplePlatformSSOProvider |  | ||||||
|  |  | ||||||
|     def post(self, request: HttpRequest) -> HttpResponse: |  | ||||||
|         version = request.POST.get("platform_sso_version") |  | ||||||
|         assertion = request.POST.get("assertion", request.POST.get("request")) |  | ||||||
|         if not assertion: |  | ||||||
|             return HttpResponse(status=400) |  | ||||||
|  |  | ||||||
|         decode_unvalidated = PyJWT().decode_complete(assertion, options={"verify_signature": False}) |  | ||||||
|         LOGGER.debug(decode_unvalidated["header"]) |  | ||||||
|         expected_kid = decode_unvalidated["header"]["kid"] |  | ||||||
|  |  | ||||||
|         self.device = AppleDevice.objects.filter(sign_key_id=expected_kid).first() |  | ||||||
|         if not self.device: |  | ||||||
|             raise Http404 |  | ||||||
|         self.provider = self.device.provider |  | ||||||
|  |  | ||||||
|         # Properly decode the JWT with the key from the device |  | ||||||
|         decoded = decode( |  | ||||||
|             assertion, self.device.signing_key, algorithms=["ES256"], options={"verify_aud": False} |  | ||||||
|         ) |  | ||||||
|         LOGGER.debug(decoded) |  | ||||||
|  |  | ||||||
|         LOGGER.debug("got device", device=self.device) |  | ||||||
|  |  | ||||||
|         # Check that the nonce hasn't been used before |  | ||||||
|         nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first() |  | ||||||
|         if not nonce: |  | ||||||
|             return HttpResponse(status=400) |  | ||||||
|         nonce.delete() |  | ||||||
|  |  | ||||||
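|         # Dispatch on the PSSO version and the JWT "typ" header: e.g. version "1.0" with |  | ||||||
|         # typ "platformsso-login-request+jwt" resolves to handle_v1_0_platformsso_login_request_jwt |  | ||||||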
|         handler_func = ( |  | ||||||
|             f"handle_v{version}_{decode_unvalidated["header"]["typ"]}".replace("-", "_") |  | ||||||
|             .replace("+", "_") |  | ||||||
|             .replace(".", "_") |  | ||||||
|         ) |  | ||||||
|         handler = getattr(self, handler_func, None) |  | ||||||
|         if not handler: |  | ||||||
|             LOGGER.debug("Handler not found", handler=handler_func) |  | ||||||
|             return HttpResponse(status=400) |  | ||||||
|         LOGGER.debug("sending to handler", handler=handler_func) |  | ||||||
|         return handler(decoded) |  | ||||||
|  |  | ||||||
|     def validate_device_user_response(self, assertion: str) -> tuple[AppleDeviceUser, dict] | None: |  | ||||||
|         """Decode an embedded assertion and validate it by looking up the matching device user""" |  | ||||||
|         decode_unvalidated = PyJWT().decode_complete(assertion, options={"verify_signature": False}) |  | ||||||
|         expected_kid = decode_unvalidated["header"]["kid"] |  | ||||||
|  |  | ||||||
|         device_user = AppleDeviceUser.objects.filter( |  | ||||||
|             device=self.device, enclave_key_id=expected_kid |  | ||||||
|         ).first() |  | ||||||
|         if not device_user: |  | ||||||
|             return None |  | ||||||
|         return device_user, decode( |  | ||||||
|             assertion, |  | ||||||
|             device_user.secure_enclave_key, |  | ||||||
|             audience="apple-platform-sso", |  | ||||||
|             algorithms=["ES256"], |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def create_auth_session(self, user: User): |  | ||||||
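|         """Create a persisted authentik session for the user, record a login event and return the encoded session key""" |  | ||||||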
|         event = Event.new(EventAction.LOGIN).from_http(self.request, user=user) |  | ||||||
|         store = SessionStore() |  | ||||||
|         store[SESSION_LOGIN_EVENT] = event |  | ||||||
|         store.save() |  | ||||||
|         session = Session.objects.filter(session_key=store.session_key).first() |  | ||||||
|         AuthenticatedSession.objects.create(session=session, user=user) |  | ||||||
|         session = SessionMiddleware.encode_session(store.session_key, user) |  | ||||||
|         return session |  | ||||||
|  |  | ||||||
|     def handle_v1_0_platformsso_login_request_jwt(self, decoded: dict): |  | ||||||
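|         """Handle a v1.0 login request JWT: validate the embedded user assertion and return tokens wrapped in a JWE""" |  | ||||||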
|         user = None |  | ||||||
|         if decoded["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer": |  | ||||||
|             # Decode and validate the embedded user assertion |  | ||||||
|             validated = self.validate_device_user_response(decoded["assertion"]) |  | ||||||
|             if not validated: |  | ||||||
|                 LOGGER.warning("No device user matching the assertion key") |  | ||||||
|                 raise ValidationError("Invalid request") |  | ||||||
|             user, inner = validated |  | ||||||
|             if inner["nonce"] != decoded["nonce"]: |  | ||||||
|                 LOGGER.warning("Mis-matched nonce to outer assertion") |  | ||||||
|                 raise ValidationError("Invalid request") |  | ||||||
|  |  | ||||||
|         refresh_token = RefreshToken( |  | ||||||
|             user=user.user, |  | ||||||
|             scope=decoded["scope"], |  | ||||||
|             expires=now() + timedelta(hours=8), |  | ||||||
|             provider=self.provider, |  | ||||||
|             auth_time=now(), |  | ||||||
|             session=None, |  | ||||||
|         ) |  | ||||||
|         id_token = IDToken.new( |  | ||||||
|             self.provider, |  | ||||||
|             refresh_token, |  | ||||||
|             self.request, |  | ||||||
|         ) |  | ||||||
|         id_token.nonce = decoded["nonce"] |  | ||||||
|         refresh_token.id_token = id_token |  | ||||||
|         refresh_token.save() |  | ||||||
|         return JWEResponse( |  | ||||||
|             { |  | ||||||
|                 "refresh_token": refresh_token.token, |  | ||||||
|                 "refresh_token_expires_in": int((refresh_token.expires - now()).total_seconds()), |  | ||||||
|                 "id_token": refresh_token.id_token.to_jwt(self.provider), |  | ||||||
|                 "token_type": TOKEN_TYPE, |  | ||||||
|                 "session_key": self.create_auth_session(user.user), |  | ||||||
|             }, |  | ||||||
|             device=self.device, |  | ||||||
|             apv=decoded["jwe_crypto"]["apv"], |  | ||||||
|         ) |  | ||||||
| @ -1,8 +1,10 @@ | |||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
|  |  | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.signals import post_delete, post_save, pre_delete | from django.db.models.signals import post_delete, post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http.request import HttpRequest | ||||||
| from guardian.shortcuts import assign_perm | from guardian.shortcuts import assign_perm | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: | |||||||
|             instance.save() |             instance.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Session revoked trigger (user logged out)""" | ||||||
|  |     if not request.session or not request.session.session_key or not user: | ||||||
|  |         return | ||||||
|  |     send_ssf_event( | ||||||
|  |         EventTypes.CAEP_SESSION_REVOKED, | ||||||
|  |         { | ||||||
|  |             "initiating_entity": "user", | ||||||
|  |         }, | ||||||
|  |         sub_id={ | ||||||
|  |             "format": "complex", | ||||||
|  |             "session": { | ||||||
|  |                 "format": "opaque", | ||||||
|  |                 "id": sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||||
|  |             }, | ||||||
|  |             "user": { | ||||||
|  |                 "format": "email", | ||||||
|  |                 "email": user.email, | ||||||
|  |             }, | ||||||
|  |         }, | ||||||
|  |         request=request, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | ||||||
|     """Session revoked trigger (user's session has been deleted) |     """Session revoked trigger (user's session has been deleted) | ||||||
|  | |||||||
| @ -1,12 +0,0 @@ | |||||||
| """Enterprise app config""" |  | ||||||
|  |  | ||||||
| from authentik.enterprise.apps import EnterpriseConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikEnterpriseSearchConfig(EnterpriseConfig): |  | ||||||
|     """Enterprise app config""" |  | ||||||
|  |  | ||||||
|     name = "authentik.enterprise.search" |  | ||||||
|     label = "authentik_search" |  | ||||||
|     verbose_name = "authentik Enterprise.Search" |  | ||||||
|     default = True |  | ||||||
| @ -1,128 +0,0 @@ | |||||||
| """DjangoQL search""" |  | ||||||
|  |  | ||||||
| from collections import OrderedDict, defaultdict |  | ||||||
| from collections.abc import Generator |  | ||||||
|  |  | ||||||
| from django.db import connection |  | ||||||
| from django.db.models import Model, Q |  | ||||||
| from djangoql.compat import text_type |  | ||||||
| from djangoql.schema import StrField |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONSearchField(StrField): |  | ||||||
|     """JSON field for DjangoQL""" |  | ||||||
|  |  | ||||||
|     model: Model |  | ||||||
|  |  | ||||||
|     def __init__(self, model=None, name=None, nullable=None, suggest_nested=True): |  | ||||||
|         # Set this in the constructor to not clobber the type variable |  | ||||||
|         self.type = "relation" |  | ||||||
|         self.suggest_nested = suggest_nested |  | ||||||
|         super().__init__(model, name, nullable) |  | ||||||
|  |  | ||||||
|     def get_lookup(self, path, operator, value): |  | ||||||
|         search = "__".join(path) |  | ||||||
|         op, invert = self.get_operator(operator) |  | ||||||
|         q = Q(**{f"{search}{op}": self.get_lookup_value(value)}) |  | ||||||
|         return ~q if invert else q |  | ||||||
|  |  | ||||||
|     def json_field_keys(self) -> Generator[tuple[str]]: |  | ||||||
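|         """Recursively collect every key path present in this JSON column (used for autocomplete suggestions)""" |  | ||||||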
|         with connection.cursor() as cursor: |  | ||||||
|             cursor.execute( |  | ||||||
|                 f""" |  | ||||||
|                 WITH RECURSIVE "{self.name}_keys" AS ( |  | ||||||
|                     SELECT |  | ||||||
|                         ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array, |  | ||||||
|                         "{self.name}" -> jsonb_object_keys("{self.name}") AS value |  | ||||||
|                     FROM {self.model._meta.db_table} |  | ||||||
|                     WHERE "{self.name}" IS NOT NULL |  | ||||||
|                         AND jsonb_typeof("{self.name}") = 'object' |  | ||||||
|  |  | ||||||
|                     UNION ALL |  | ||||||
|  |  | ||||||
|                     SELECT |  | ||||||
|                         ck.key_path_array || jsonb_object_keys(ck.value), |  | ||||||
|                         ck.value -> jsonb_object_keys(ck.value) AS value |  | ||||||
|                     FROM "{self.name}_keys" ck |  | ||||||
|                     WHERE jsonb_typeof(ck.value) = 'object' |  | ||||||
|                 ), |  | ||||||
|  |  | ||||||
|                 unique_paths AS ( |  | ||||||
|                     SELECT DISTINCT key_path_array |  | ||||||
|                     FROM "{self.name}_keys" |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|                 SELECT key_path_array FROM unique_paths; |  | ||||||
|             """  # nosec |  | ||||||
|             ) |  | ||||||
|             return (x[0] for x in cursor.fetchall()) |  | ||||||
|  |  | ||||||
|     def get_nested_options(self) -> OrderedDict: |  | ||||||
|         """Get keys of all nested objects to show autocomplete""" |  | ||||||
|         if not self.suggest_nested: |  | ||||||
|             return OrderedDict() |  | ||||||
|         base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" |  | ||||||
|  |  | ||||||
|         def recursive_function(parts: list[str], parent_parts: list[str] | None = None): |  | ||||||
|             if not parent_parts: |  | ||||||
|                 parent_parts = [] |  | ||||||
|             path = parts.pop(0) |  | ||||||
|             parent_parts.append(path) |  | ||||||
|             relation_key = "_".join(parent_parts) |  | ||||||
|             if len(parts) > 1: |  | ||||||
|                 out_dict = { |  | ||||||
|                     relation_key: { |  | ||||||
|                         parts[0]: { |  | ||||||
|                             "type": "relation", |  | ||||||
|                             "relation": f"{relation_key}_{parts[0]}", |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|                 child_paths = recursive_function(parts.copy(), parent_parts.copy()) |  | ||||||
|                 child_paths.update(out_dict) |  | ||||||
|                 return child_paths |  | ||||||
|             else: |  | ||||||
|                 return {relation_key: {parts[0]: {}}} |  | ||||||
|  |  | ||||||
|         relation_structure = defaultdict(dict) |  | ||||||
|  |  | ||||||
|         for relations in self.json_field_keys(): |  | ||||||
|             result = recursive_function([base_model_name] + relations) |  | ||||||
|             for relation_key, value in result.items(): |  | ||||||
|                 for sub_relation_key, sub_value in value.items(): |  | ||||||
|                     if not relation_structure[relation_key].get(sub_relation_key, None): |  | ||||||
|                         relation_structure[relation_key][sub_relation_key] = sub_value |  | ||||||
|                     else: |  | ||||||
|                         relation_structure[relation_key][sub_relation_key].update(sub_value) |  | ||||||
|  |  | ||||||
|         final_dict = defaultdict(dict) |  | ||||||
|  |  | ||||||
|         for key, value in relation_structure.items(): |  | ||||||
|             for sub_key, sub_value in value.items(): |  | ||||||
|                 if not sub_value: |  | ||||||
|                     final_dict[key][sub_key] = { |  | ||||||
|                         "type": "str", |  | ||||||
|                         "nullable": True, |  | ||||||
|                     } |  | ||||||
|                 else: |  | ||||||
|                     final_dict[key][sub_key] = sub_value |  | ||||||
|         return OrderedDict(final_dict) |  | ||||||
|  |  | ||||||
|     def relation(self) -> str: |  | ||||||
|         return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ChoiceSearchField(StrField): |  | ||||||
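|     """Search field that suggests the choices defined on the underlying model field""" |  | ||||||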
|     def __init__(self, model=None, name=None, nullable=None): |  | ||||||
|         super().__init__(model, name, nullable, suggest_options=True) |  | ||||||
|  |  | ||||||
|     def get_options(self, search): |  | ||||||
|         result = [] |  | ||||||
|         choices = self._field_choices() |  | ||||||
|         if choices: |  | ||||||
|             search = search.lower() |  | ||||||
|             for c in choices: |  | ||||||
|                 choice = text_type(c[0]) |  | ||||||
|                 if search in choice.lower(): |  | ||||||
|                     result.append(choice) |  | ||||||
|         return result |  | ||||||
| @ -1,53 +0,0 @@ | |||||||
| from rest_framework.response import Response |  | ||||||
|  |  | ||||||
| from authentik.api.pagination import Pagination |  | ||||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AutocompletePagination(Pagination): |  | ||||||
|  |  | ||||||
|     def paginate_queryset(self, queryset, request, view=None): |  | ||||||
|         self.view = view |  | ||||||
|         return super().paginate_queryset(queryset, request, view) |  | ||||||
|  |  | ||||||
|     def get_autocomplete(self): |  | ||||||
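|         """Serialize the DjangoQL schema for views that expose get_ql_fields, for client-side autocomplete""" |  | ||||||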
|         schema = QLSearch().get_schema(self.request, self.view) |  | ||||||
|         introspections = {} |  | ||||||
|         if hasattr(self.view, "get_ql_fields"): |  | ||||||
|             from authentik.enterprise.search.schema import AKQLSchemaSerializer |  | ||||||
|  |  | ||||||
|             introspections = AKQLSchemaSerializer().serialize( |  | ||||||
|                 schema(self.page.paginator.object_list.model) |  | ||||||
|             ) |  | ||||||
|         return introspections |  | ||||||
|  |  | ||||||
|     def get_paginated_response(self, data): |  | ||||||
|         previous_page_number = 0 |  | ||||||
|         if self.page.has_previous(): |  | ||||||
|             previous_page_number = self.page.previous_page_number() |  | ||||||
|         next_page_number = 0 |  | ||||||
|         if self.page.has_next(): |  | ||||||
|             next_page_number = self.page.next_page_number() |  | ||||||
|         return Response( |  | ||||||
|             { |  | ||||||
|                 "pagination": { |  | ||||||
|                     "next": next_page_number, |  | ||||||
|                     "previous": previous_page_number, |  | ||||||
|                     "count": self.page.paginator.count, |  | ||||||
|                     "current": self.page.number, |  | ||||||
|                     "total_pages": self.page.paginator.num_pages, |  | ||||||
|                     "start_index": self.page.start_index(), |  | ||||||
|                     "end_index": self.page.end_index(), |  | ||||||
|                 }, |  | ||||||
|                 "results": data, |  | ||||||
|                 "autocomplete": self.get_autocomplete(), |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_paginated_response_schema(self, schema): |  | ||||||
|         final_schema = super().get_paginated_response_schema(schema) |  | ||||||
|         final_schema["properties"]["autocomplete"] = { |  | ||||||
|             "$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}" |  | ||||||
|         } |  | ||||||
|         final_schema["required"].append("autocomplete") |  | ||||||
|         return final_schema |  | ||||||
| @ -1,78 +0,0 @@ | |||||||
| """DjangoQL search""" |  | ||||||
|  |  | ||||||
| from django.apps import apps |  | ||||||
| from django.db.models import QuerySet |  | ||||||
| from djangoql.ast import Name |  | ||||||
| from djangoql.exceptions import DjangoQLError |  | ||||||
| from djangoql.queryset import apply_search |  | ||||||
| from djangoql.schema import DjangoQLSchema |  | ||||||
| from rest_framework.filters import SearchFilter |  | ||||||
| from rest_framework.request import Request |  | ||||||
| from structlog.stdlib import get_logger |  | ||||||
|  |  | ||||||
| from authentik.enterprise.search.fields import JSONSearchField |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
| AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete" |  | ||||||
| AUTOCOMPLETE_SCHEMA = { |  | ||||||
|     "type": "object", |  | ||||||
|     "additionalProperties": {}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseSchema(DjangoQLSchema): |  | ||||||
|     """Base Schema which deals with JSON Fields""" |  | ||||||
|  |  | ||||||
|     def resolve_name(self, name: Name): |  | ||||||
|         model = self.model_label(self.current_model) |  | ||||||
|         root_field = name.parts[0] |  | ||||||
|         field = self.models[model].get(root_field) |  | ||||||
|         # If the query goes into a JSON field, return the root |  | ||||||
|         # field as the JSON field will do the rest |  | ||||||
|         if isinstance(field, JSONSearchField): |  | ||||||
|             # Workaround: build_filter strips the right-most entry in the path, |  | ||||||
|             # assuming it matches the field name; for JSON lookups that assumption |  | ||||||
|             # does not hold, so re-append the root field here |  | ||||||
|             if name.parts[-1] != root_field: |  | ||||||
|                 name.parts.append(root_field) |  | ||||||
|             return field |  | ||||||
|         return super().resolve_name(name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class QLSearch(SearchFilter): |  | ||||||
|     """rest_framework search filter which uses DjangoQL""" |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def enabled(self): |  | ||||||
|         return apps.get_app_config("authentik_enterprise").enabled() |  | ||||||
|  |  | ||||||
|     def get_search_terms(self, request) -> str: |  | ||||||
|         """ |  | ||||||
|         Return the raw search string from the ?search=... query parameter |  | ||||||
|         (null characters stripped); it is not split into individual terms. |  | ||||||
|         """ |  | ||||||
|         params = request.query_params.get(self.search_param, "") |  | ||||||
|         params = params.replace("\x00", "")  # strip null characters |  | ||||||
|         return params |  | ||||||
|  |  | ||||||
|     def get_schema(self, request: Request, view) -> BaseSchema: |  | ||||||
|         ql_fields = [] |  | ||||||
|         if hasattr(view, "get_ql_fields"): |  | ||||||
|             ql_fields = view.get_ql_fields() |  | ||||||
|  |  | ||||||
|         class InlineSchema(BaseSchema): |  | ||||||
|             def get_fields(self, model): |  | ||||||
|                 return ql_fields or [] |  | ||||||
|  |  | ||||||
|         return InlineSchema |  | ||||||
|  |  | ||||||
|     def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet: |  | ||||||
|         search_query = self.get_search_terms(request) |  | ||||||
|         schema = self.get_schema(request, view) |  | ||||||
|         if len(search_query) == 0 or not self.enabled: |  | ||||||
|             return super().filter_queryset(request, queryset, view) |  | ||||||
|         try: |  | ||||||
|             return apply_search(queryset, search_query, schema=schema) |  | ||||||
|         except DjangoQLError as exc: |  | ||||||
|             LOGGER.debug("Failed to parse search expression", exc=exc) |  | ||||||
|             return super().filter_queryset(request, queryset, view) |  | ||||||
| @ -1,29 +0,0 @@ | |||||||
| from djangoql.serializers import DjangoQLSchemaSerializer |  | ||||||
| from drf_spectacular.generators import SchemaGenerator |  | ||||||
|  |  | ||||||
| from authentik.api.schema import create_component |  | ||||||
| from authentik.enterprise.search.fields import JSONSearchField |  | ||||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AKQLSchemaSerializer(DjangoQLSchemaSerializer): |  | ||||||
|     def serialize(self, schema): |  | ||||||
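|         """Extend the default DjangoQL schema serialization with nested options for JSON fields""" |  | ||||||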
|         serialization = super().serialize(schema) |  | ||||||
|         for _, fields in schema.models.items(): |  | ||||||
|             for _, field in fields.items(): |  | ||||||
|                 if not isinstance(field, JSONSearchField): |  | ||||||
|                     continue |  | ||||||
|                 serialization["models"].update(field.get_nested_options()) |  | ||||||
|         return serialization |  | ||||||
|  |  | ||||||
|     def serialize_field(self, field): |  | ||||||
|         result = super().serialize_field(field) |  | ||||||
|         if isinstance(field, JSONSearchField): |  | ||||||
|             result["relation"] = field.relation() |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs): |  | ||||||
|     create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA) |  | ||||||
|  |  | ||||||
|     return result |  | ||||||
| @ -1,17 +0,0 @@ | |||||||
| SPECTACULAR_SETTINGS = { |  | ||||||
|     "POSTPROCESSING_HOOKS": [ |  | ||||||
|         "authentik.api.schema.postprocess_schema_responses", |  | ||||||
|         "authentik.enterprise.search.schema.postprocess_schema_search_autocomplete", |  | ||||||
|         "drf_spectacular.hooks.postprocess_schema_enums", |  | ||||||
|     ], |  | ||||||
| } |  | ||||||
|  |  | ||||||
| REST_FRAMEWORK = { |  | ||||||
|     "DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination", |  | ||||||
|     "DEFAULT_FILTER_BACKENDS": [ |  | ||||||
|         "authentik.enterprise.search.ql.QLSearch", |  | ||||||
|         "authentik.rbac.filters.ObjectFilter", |  | ||||||
|         "django_filters.rest_framework.DjangoFilterBackend", |  | ||||||
|         "rest_framework.filters.OrderingFilter", |  | ||||||
|     ], |  | ||||||
| } |  | ||||||
| @ -1,78 +0,0 @@ | |||||||
| from json import loads |  | ||||||
| from unittest.mock import PropertyMock, patch |  | ||||||
| from urllib.parse import urlencode |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @patch( |  | ||||||
|     "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|     PropertyMock(return_value=True), |  | ||||||
| ) |  | ||||||
| class QLTest(APITestCase): |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         # ensure we have more than 1 user |  | ||||||
|         create_test_admin_user() |  | ||||||
|  |  | ||||||
|     def test_search(self): |  | ||||||
|         """Test simple search query""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         query = f'username = "{self.user.username}"' |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": query})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
|  |  | ||||||
|     def test_no_search(self): |  | ||||||
|         """Ensure works with no search query""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertNotEqual(content["pagination"]["count"], 1) |  | ||||||
|  |  | ||||||
|     def test_search_no_ql(self): |  | ||||||
|         """Test simple search query (no QL)""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": self.user.username})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertGreaterEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
|  |  | ||||||
|     def test_search_json(self): |  | ||||||
|         """Test search query with a JSON attribute""" |  | ||||||
|         self.user.attributes = {"foo": {"bar": "baz"}} |  | ||||||
|         self.user.save() |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         query = 'attributes.foo.bar = "baz"' |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": query})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
| @ -15,11 +15,9 @@ CELERY_BEAT_SCHEDULE = { | |||||||
| TENANT_APPS = [ | TENANT_APPS = [ | ||||||
|     "authentik.enterprise.audit", |     "authentik.enterprise.audit", | ||||||
|     "authentik.enterprise.policies.unique_password", |     "authentik.enterprise.policies.unique_password", | ||||||
|     "authentik.enterprise.providers.apple_psso", |  | ||||||
|     "authentik.enterprise.providers.google_workspace", |     "authentik.enterprise.providers.google_workspace", | ||||||
|     "authentik.enterprise.providers.microsoft_entra", |     "authentik.enterprise.providers.microsoft_entra", | ||||||
|     "authentik.enterprise.providers.ssf", |     "authentik.enterprise.providers.ssf", | ||||||
|     "authentik.enterprise.search", |  | ||||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", |     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||||
|     "authentik.enterprise.stages.mtls", |     "authentik.enterprise.stages.mtls", | ||||||
|     "authentik.enterprise.stages.source", |     "authentik.enterprise.stages.source", | ||||||
|  | |||||||
| @ -97,7 +97,6 @@ class SourceStageFinal(StageView): | |||||||
|         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) |         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) | ||||||
|         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) |         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) | ||||||
|         plan = token.plan |         plan = token.plan | ||||||
|         plan.context.update(self.executor.plan.context) |  | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = token |         plan.context[PLAN_CONTEXT_IS_RESTORED] = token | ||||||
|         response = plan.to_redirect(self.request, token.flow) |         response = plan.to_redirect(self.request, token.flow) | ||||||
|         token.delete() |         token.delete() | ||||||
|  | |||||||
| @ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase): | |||||||
|         plan: FlowPlan = session[SESSION_KEY_PLAN] |         plan: FlowPlan = session[SESSION_KEY_PLAN] | ||||||
|         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) |         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token |         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token | ||||||
|         plan.context["foo"] = "bar" |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|         session.save() |         session.save() | ||||||
|  |  | ||||||
|         # Pretend we've just returned from the source |         # Pretend we've just returned from the source | ||||||
|         with self.assertFlowFinishes() as ff: |         response = self.client.get( | ||||||
|             response = self.client.get( |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True |         ) | ||||||
|             ) |         self.assertEqual(response.status_code, 200) | ||||||
|             self.assertEqual(response.status_code, 200) |         self.assertStageRedirects( | ||||||
|             self.assertStageRedirects( |             response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||||
|                 response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) |         ) | ||||||
|             ) |  | ||||||
|         self.assertEqual(ff().context["foo"], "bar") |  | ||||||
|  | |||||||
| @ -132,22 +132,6 @@ class EventViewSet(ModelViewSet): | |||||||
|     ] |     ] | ||||||
|     filterset_class = EventsFilter |     filterset_class = EventsFilter | ||||||
|  |  | ||||||
|     def get_ql_fields(self): |  | ||||||
|         from djangoql.schema import DateTimeField, StrField |  | ||||||
|  |  | ||||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ChoiceSearchField(Event, "action"), |  | ||||||
|             StrField(Event, "event_uuid"), |  | ||||||
|             StrField(Event, "app", suggest_options=True), |  | ||||||
|             StrField(Event, "client_ip"), |  | ||||||
|             JSONSearchField(Event, "user", suggest_nested=False), |  | ||||||
|             JSONSearchField(Event, "brand", suggest_nested=False), |  | ||||||
|             JSONSearchField(Event, "context", suggest_nested=False), |  | ||||||
|             DateTimeField(Event, "created", suggest_options=True), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         methods=["GET"], |         methods=["GET"], | ||||||
|         responses={200: EventTopPerUserSerializer(many=True)}, |         responses={200: EventTopPerUserSerializer(many=True)}, | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ from authentik.events.models import NotificationRule | |||||||
| class NotificationRuleSerializer(ModelSerializer): | class NotificationRuleSerializer(ModelSerializer): | ||||||
|     """NotificationRule Serializer""" |     """NotificationRule Serializer""" | ||||||
|  |  | ||||||
|     destination_group_obj = GroupSerializer(read_only=True, source="destination_group") |     group_obj = GroupSerializer(read_only=True, source="group") | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = NotificationRule |         model = NotificationRule | ||||||
| @ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer): | |||||||
|             "name", |             "name", | ||||||
|             "transports", |             "transports", | ||||||
|             "severity", |             "severity", | ||||||
|             "destination_group", |             "group", | ||||||
|             "destination_group_obj", |             "group_obj", | ||||||
|             "destination_event_user", |  | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet): | |||||||
|  |  | ||||||
|     queryset = NotificationRule.objects.all() |     queryset = NotificationRule.objects.all() | ||||||
|     serializer_class = NotificationRuleSerializer |     serializer_class = NotificationRuleSerializer | ||||||
|     filterset_fields = ["name", "severity", "destination_group__name"] |     filterset_fields = ["name", "severity", "group__name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     search_fields = ["name", "destination_group__name"] |     search_fields = ["name", "group__name"] | ||||||
|  | |||||||
| @ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor): | |||||||
|         self.reader: Reader | None = None |         self.reader: Reader | None = None | ||||||
|         self._last_mtime: float = 0.0 |         self._last_mtime: float = 0.0 | ||||||
|         self.logger = get_logger() |         self.logger = get_logger() | ||||||
|         self.load() |         self.open() | ||||||
|  |  | ||||||
|     def path(self) -> str | None: |     def path(self) -> str | None: | ||||||
|         """Get the path to the MMDB file to load""" |         """Get the path to the MMDB file to load""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def load(self): |     def open(self): | ||||||
|         """Get GeoIP Reader, if configured, otherwise none""" |         """Get GeoIP Reader, if configured, otherwise none""" | ||||||
|         path = self.path() |         path = self.path() | ||||||
|         if path == "" or not path: |         if path == "" or not path: | ||||||
| @ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor): | |||||||
|             diff = self._last_mtime < mtime |             diff = self._last_mtime < mtime | ||||||
|             if diff > 0: |             if diff > 0: | ||||||
|                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) |                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) | ||||||
|                 self.load() |                 self.open() | ||||||
|         except OSError as exc: |         except OSError as exc: | ||||||
|             self.logger.warning("Failed to check MMDB age", exc=exc) |             self.logger.warning("Failed to check MMDB age", exc=exc) | ||||||
|  |  | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models | |||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.models import Event, EventAction, Notification | from authentik.events.models import Event, EventAction, Notification | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
| from authentik.lib.sentry import should_ignore_exception | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.stages.authenticator_static.models import StaticToken | from authentik.stages.authenticator_static.models import StaticToken | ||||||
|  |  | ||||||
| @ -173,7 +173,7 @@ class AuditMiddleware: | |||||||
|                 message=exception_to_string(exception), |                 message=exception_to_string(exception), | ||||||
|             ) |             ) | ||||||
|             thread.run() |             thread.run() | ||||||
|         elif not should_ignore_exception(exception): |         elif before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||||
|             thread = EventNewThread( |             thread = EventNewThread( | ||||||
|                 EventAction.SYSTEM_EXCEPTION, |                 EventAction.SYSTEM_EXCEPTION, | ||||||
|                 request, |                 request, | ||||||
|  | |||||||
| @ -1,26 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-16 23:21 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.RenameField( |  | ||||||
|             model_name="notificationrule", |  | ||||||
|             old_name="group", |  | ||||||
|             new_name="destination_group", |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="notificationrule", |  | ||||||
|             name="destination_event_user", |  | ||||||
|             field=models.BooleanField( |  | ||||||
|                 default=False, |  | ||||||
|                 help_text="When enabled, the notification will be sent to the user that triggered the event. When destination_group is configured, the notification is sent to both.", |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,12 +1,10 @@ | |||||||
| """authentik events models""" | """authentik events models""" | ||||||
|  |  | ||||||
| from collections.abc import Generator |  | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from difflib import get_close_matches | from difflib import get_close_matches | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from inspect import currentframe | from inspect import currentframe | ||||||
| from smtplib import SMTPException | from smtplib import SMTPException | ||||||
| from typing import Any |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| @ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|             brand: Brand = request.brand |             brand: Brand = request.brand | ||||||
|             self.brand = sanitize_dict(model_to_dict(brand)) |             self.brand = sanitize_dict(model_to_dict(brand)) | ||||||
|         if hasattr(request, "user"): |         if hasattr(request, "user"): | ||||||
|             self.user = get_user(request.user) |             original_user = None | ||||||
|  |             if hasattr(request, "session"): | ||||||
|  |                 original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None) | ||||||
|  |             self.user = get_user(request.user, original_user) | ||||||
|         if user: |         if user: | ||||||
|             self.user = get_user(user) |             self.user = get_user(user) | ||||||
|  |         # Check if we're currently impersonating, and add that user | ||||||
|         if hasattr(request, "session"): |         if hasattr(request, "session"): | ||||||
|             from authentik.flows.views.executor import SESSION_KEY_PLAN |  | ||||||
|  |  | ||||||
|             # Check if we're currently impersonating, and add that user |  | ||||||
|             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: |             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: | ||||||
|                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) |                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) | ||||||
|                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) |                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) | ||||||
|             # Special case for events that happen during a flow, the user might not be authenticated |  | ||||||
|             # yet but is a pending user instead |  | ||||||
|             if SESSION_KEY_PLAN in request.session: |  | ||||||
|                 from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan |  | ||||||
|  |  | ||||||
|                 plan: FlowPlan = request.session[SESSION_KEY_PLAN] |  | ||||||
|                 pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None) |  | ||||||
|                 # Only save `authenticated_as` if there's a different pending user in the flow |  | ||||||
|                 # than the user that is authenticated |  | ||||||
|                 if pending_user and ( |  | ||||||
|                     (pending_user.pk and pending_user.pk != self.user.get("pk")) |  | ||||||
|                     or (not pending_user.pk) |  | ||||||
|                 ): |  | ||||||
|                     orig_user = self.user.copy() |  | ||||||
|  |  | ||||||
|                     self.user = {"authenticated_as": orig_user, **get_user(pending_user)} |  | ||||||
|         # Use 255.255.255.255 as fallback if IP cannot be determined |         # Use 255.255.255.255 as fallback if IP cannot be determined | ||||||
|         self.client_ip = ClientIPMiddleware.get_client_ip(request) |         self.client_ip = ClientIPMiddleware.get_client_ip(request) | ||||||
|         # Enrich event data |         # Enrich event data | ||||||
| @ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | |||||||
|         default=NotificationSeverity.NOTICE, |         default=NotificationSeverity.NOTICE, | ||||||
|         help_text=_("Controls which severity level the created notifications will have."), |         help_text=_("Controls which severity level the created notifications will have."), | ||||||
|     ) |     ) | ||||||
|     destination_group = models.ForeignKey( |     group = models.ForeignKey( | ||||||
|         Group, |         Group, | ||||||
|         help_text=_( |         help_text=_( | ||||||
|             "Define which group of users this notification should be sent and shown to. " |             "Define which group of users this notification should be sent and shown to. " | ||||||
| @ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | |||||||
|         blank=True, |         blank=True, | ||||||
|         on_delete=models.SET_NULL, |         on_delete=models.SET_NULL, | ||||||
|     ) |     ) | ||||||
|     destination_event_user = models.BooleanField( |  | ||||||
|         default=False, |  | ||||||
|         help_text=_( |  | ||||||
|             "When enabled, notification will be sent to the user that triggered the event. " |  | ||||||
|             "When destination_group is configured, notification is sent to both." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     def destination_users(self, event: Event) -> Generator[User, Any]: |  | ||||||
|         if self.destination_event_user and event.user.get("pk"): |  | ||||||
|             yield User(pk=event.user.get("pk")) |  | ||||||
|         if self.destination_group: |  | ||||||
|             yield from self.destination_group.users.all() |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|  | |||||||
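For context on the destination_group/destination_event_user fields above: the destination_users generator shown in this hunk yields the event's triggering user (when enabled) followed by the configured group's members. A framework-free sketch of that recipient-resolution logic; FakeUser, FakeGroup, and FakeRule are illustrative stand-ins, not authentik models.

    from collections.abc import Generator
    from dataclasses import dataclass, field

    @dataclass
    class FakeUser:  # stand-in for authentik.core.models.User
        pk: int

    @dataclass
    class FakeGroup:  # stand-in for authentik.core.models.Group
        users: list[FakeUser] = field(default_factory=list)

    @dataclass
    class FakeRule:  # stand-in for NotificationRule's destination fields
        destination_event_user: bool = False
        destination_group: FakeGroup | None = None

        def destination_users(self, event_user_pk: int | None) -> Generator[FakeUser, None, None]:
            # mirrors the method above: event user first, then group members
            if self.destination_event_user and event_user_pk:
                yield FakeUser(pk=event_user_pk)
            if self.destination_group:
                yield from self.destination_group.users

    rule = FakeRule(destination_event_user=True, destination_group=FakeGroup([FakeUser(2), FakeUser(3)]))
    assert [u.pk for u in rule.destination_users(event_user_pk=1)] == [1, 2, 3]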
| @ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|     if not result.passing: |     if not result.passing: | ||||||
|         return |         return | ||||||
|  |  | ||||||
|  |     if not trigger.group: | ||||||
|  |         LOGGER.debug("e(trigger): trigger has no group", trigger=trigger) | ||||||
|  |         return | ||||||
|  |  | ||||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) |     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||||
|     # Create the notification objects |     # Create the notification objects | ||||||
|     for transport in trigger.transports.all(): |     for transport in trigger.transports.all(): | ||||||
|         for user in trigger.destination_users(event): |         for user in trigger.group.users.all(): | ||||||
|             LOGGER.debug("created notification") |             LOGGER.debug("created notification") | ||||||
|             notification_transport.apply_async( |             notification_transport.apply_async( | ||||||
|                 args=[ |                 args=[ | ||||||
|  | |||||||
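The loop above dispatches one notification task per (transport, recipient) pair. A tiny sketch of that fan-out shape with plain lists; the names here are hypothetical, and the real code calls notification_transport.apply_async for each pair.

    from itertools import product

    transports = ["email", "webhook"]  # stand-in for trigger.transports.all()
    recipients = ["alice", "bob"]      # stand-in for the resolved recipients

    # one dispatch per (transport, user) pair, mirroring the nested loops
    dispatched = list(product(transports, recipients))
    assert len(dispatched) == len(transports) * len(recipients)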
| @ -2,9 +2,7 @@ | |||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.events.context_processors.base import get_context_processors |  | ||||||
| from authentik.events.context_processors.geoip import GeoIPContextProcessor | from authentik.events.context_processors.geoip import GeoIPContextProcessor | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestGeoIP(TestCase): | class TestGeoIP(TestCase): | ||||||
| @ -15,7 +13,8 @@ class TestGeoIP(TestCase): | |||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         """Test simple city wrapper""" |         """Test simple city wrapper""" | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json |         # IPs from | ||||||
|  |         # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             self.reader.city_dict("2.125.160.216"), |             self.reader.city_dict("2.125.160.216"), | ||||||
|             { |             { | ||||||
| @ -26,12 +25,3 @@ class TestGeoIP(TestCase): | |||||||
|                 "long": -1.25, |                 "long": -1.25, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_special_chars(self): |  | ||||||
|         """Test city name with special characters""" |  | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json |  | ||||||
|         event = Event.new(EventAction.LOGIN) |  | ||||||
|         event.client_ip = "89.160.20.112" |  | ||||||
|         for processor in get_context_processors(): |  | ||||||
|             processor.enrich_event(event) |  | ||||||
|         event.save() |  | ||||||
|  | |||||||
| @ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter | |||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import Event | from authentik.events.models import Event | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | from authentik.flows.views.executor import QS_QUERY | ||||||
| from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
|  |  | ||||||
| @ -118,92 +116,3 @@ class TestEvents(TestCase): | |||||||
|                 "pk": brand.pk.hex, |                 "pk": brand.pk.hex, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = user |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user_anon(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         anon = get_anonymous_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = anon |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "authenticated_as": { |  | ||||||
|                     "pk": anon.pk, |  | ||||||
|                     "is_anonymous": True, |  | ||||||
|                     "username": "AnonymousUser", |  | ||||||
|                     "email": "", |  | ||||||
|                 }, |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user_fake(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = User( |  | ||||||
|             username=generate_id(), |  | ||||||
|             email=generate_id(), |  | ||||||
|         ) |  | ||||||
|         anon = get_anonymous_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = anon |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "authenticated_as": { |  | ||||||
|                     "pk": anon.pk, |  | ||||||
|                     "is_anonymous": True, |  | ||||||
|                     "username": "AnonymousUser", |  | ||||||
|                     "email": "", |  | ||||||
|                 }, |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from django.urls import reverse | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
|     Event, |     Event, | ||||||
|     EventAction, |     EventAction, | ||||||
| @ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|     def test_trigger_empty(self): |     def test_trigger_empty(self): | ||||||
|         """Test trigger without any policies attached""" |         """Test trigger without any policies attached""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|  |  | ||||||
| @ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|     def test_trigger_single(self): |     def test_trigger_single(self): | ||||||
|         """Test simple transport triggering""" |         """Test simple transport triggering""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase): | |||||||
|             Event.new(EventAction.CUSTOM_PREFIX).save() |             Event.new(EventAction.CUSTOM_PREFIX).save() | ||||||
|         self.assertEqual(execute_mock.call_count, 1) |         self.assertEqual(execute_mock.call_count, 1) | ||||||
|  |  | ||||||
|     def test_trigger_event_user(self): |  | ||||||
|         """Test trigger with event user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |  | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True) |  | ||||||
|         trigger.transports.add(transport) |  | ||||||
|         trigger.save() |  | ||||||
|         matcher = EventMatcherPolicy.objects.create( |  | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |  | ||||||
|         ) |  | ||||||
|         PolicyBinding.objects.create(target=trigger, policy=matcher, order=0) |  | ||||||
|  |  | ||||||
|         execute_mock = MagicMock() |  | ||||||
|         with patch("authentik.events.models.NotificationTransport.send", execute_mock): |  | ||||||
|             Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save() |  | ||||||
|         self.assertEqual(execute_mock.call_count, 1) |  | ||||||
|         notification: Notification = execute_mock.call_args[0][0] |  | ||||||
|         self.assertEqual(notification.user, user) |  | ||||||
|  |  | ||||||
|     def test_trigger_no_group(self): |     def test_trigger_no_group(self): | ||||||
|         """Test trigger without group""" |         """Test trigger without group""" | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id()) |         trigger = NotificationRule.objects.create(name=generate_id()) | ||||||
| @ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|         """Test Policy error which would cause recursion""" |         """Test Policy error which would cause recursion""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|  |  | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) |         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL |             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL | ||||||
|         ) |         ) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||||
|  | |||||||
| @ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]: | |||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_user(user: User | AnonymousUser) -> dict[str, Any]: | def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||||
|     """Convert user object to dictionary""" |     """Convert user object to dictionary, optionally including the original user""" | ||||||
|     if isinstance(user, AnonymousUser): |     if isinstance(user, AnonymousUser): | ||||||
|         try: |         try: | ||||||
|             user = get_anonymous_user() |             user = get_anonymous_user() | ||||||
| @ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]: | |||||||
|     } |     } | ||||||
|     if user.username == settings.ANONYMOUS_USER_NAME: |     if user.username == settings.ANONYMOUS_USER_NAME: | ||||||
|         user_data["is_anonymous"] = True |         user_data["is_anonymous"] = True | ||||||
|  |     if original_user: | ||||||
|  |         original_data = get_user(original_user) | ||||||
|  |         original_data["on_behalf_of"] = user_data | ||||||
|  |         return original_data | ||||||
|     return user_data |     return user_data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
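The optional original_user parameter on get_user above changes the shape of the returned dict during impersonation: the impersonating user becomes the top-level subject and the acted-as user is nested under on_behalf_of. A simplified sketch of those shapes using plain dicts; summarize and get_user_sketch are illustrative helpers, not the real utilities.

    def summarize(username: str, pk: int, email: str = "") -> dict:
        # simplified stand-in for the dict built by get_user()
        return {"username": username, "pk": pk, "email": email}

    def get_user_sketch(user: dict, original_user: dict | None = None) -> dict:
        # mirrors the optional-original-user branch: the original (impersonating)
        # user becomes the top-level subject, with the acted-as user nested
        if original_user:
            original = dict(original_user)
            original["on_behalf_of"] = user
            return original
        return user

    acted_as = summarize("jane", 2)
    admin = summarize("akadmin", 1)
    assert get_user_sketch(acted_as) == acted_as
    assert get_user_sketch(acted_as, admin)["on_behalf_of"]["username"] == "jane"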
| @ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch | |||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.test import override_settings |  | ||||||
| from django.test.client import RequestFactory | from django.test.client import RequestFactory | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.exceptions import ParseError |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_flow, create_test_user | from authentik.core.tests.utils import create_test_flow, create_test_user | ||||||
| @ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase): | |||||||
|             self.assertStageResponse(response, flow, component="ak-stage-identification") |             self.assertStageResponse(response, flow, component="ak-stage-identification") | ||||||
|             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) |             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) | ||||||
|             self.assertStageResponse(response, flow, component="ak-stage-access-denied") |             self.assertStageResponse(response, flow, component="ak-stage-access-denied") | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.flows.views.executor.to_stage_response", |  | ||||||
|         TO_STAGE_RESPONSE_MOCK, |  | ||||||
|     ) |  | ||||||
|     def test_invalid_json(self): |  | ||||||
|         """Test invalid JSON body""" |  | ||||||
|         flow = create_test_flow() |  | ||||||
|         FlowStageBinding.objects.create( |  | ||||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 |  | ||||||
|         ) |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |  | ||||||
|  |  | ||||||
|         with override_settings(TEST=False, DEBUG=False): |  | ||||||
|             self.client.logout() |  | ||||||
|             response = self.client.post(url, data="{", content_type="application/json") |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|         with self.assertRaises(ParseError): |  | ||||||
|             self.client.logout() |  | ||||||
|             response = self.client.post(url, data="{", content_type="application/json") |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ from authentik.flows.planner import ( | |||||||
|     FlowPlanner, |     FlowPlanner, | ||||||
| ) | ) | ||||||
| from authentik.flows.stage import AccessDeniedStage, StageView | from authentik.flows.stage import AccessDeniedStage, StageView | ||||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.lib.utils.reflection import all_subclasses, class_to_path | from authentik.lib.utils.reflection import all_subclasses, class_to_path | ||||||
| from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | ||||||
| @ -234,13 +234,12 @@ class FlowExecutorView(APIView): | |||||||
|         """Handle exception in stage execution""" |         """Handle exception in stage execution""" | ||||||
|         if settings.DEBUG or settings.TEST: |         if settings.DEBUG or settings.TEST: | ||||||
|             raise exc |             raise exc | ||||||
|  |         capture_exception(exc) | ||||||
|         self._logger.warning(exc) |         self._logger.warning(exc) | ||||||
|         if not should_ignore_exception(exc): |         Event.new( | ||||||
|             capture_exception(exc) |             action=EventAction.SYSTEM_EXCEPTION, | ||||||
|             Event.new( |             message=exception_to_string(exc), | ||||||
|                 action=EventAction.SYSTEM_EXCEPTION, |         ).from_http(self.request) | ||||||
|                 message=exception_to_string(exc), |  | ||||||
|             ).from_http(self.request) |  | ||||||
|         challenge = FlowErrorChallenge(self.request, exc) |         challenge = FlowErrorChallenge(self.request, exc) | ||||||
|         challenge.is_valid(raise_exception=True) |         challenge.is_valid(raise_exception=True) | ||||||
|         return to_stage_response(self.request, HttpChallengeResponse(challenge)) |         return to_stage_response(self.request, HttpChallengeResponse(challenge)) | ||||||
|  | |||||||
| @ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted | |||||||
| from docker.errors import DockerException | from docker.errors import DockerException | ||||||
| from h11 import LocalProtocolError | from h11 import LocalProtocolError | ||||||
| from ldap3.core.exceptions import LDAPException | from ldap3.core.exceptions import LDAPException | ||||||
| from psycopg.errors import Error |  | ||||||
| from redis.exceptions import ConnectionError as RedisConnectionError | from redis.exceptions import ConnectionError as RedisConnectionError | ||||||
| from redis.exceptions import RedisError, ResponseError | from redis.exceptions import RedisError, ResponseError | ||||||
| from rest_framework.exceptions import APIException | from rest_framework.exceptions import APIException | ||||||
| @ -45,49 +44,6 @@ class SentryIgnoredException(Exception): | |||||||
|     """Base Class for all errors that are suppressed, and not sent to sentry.""" |     """Base Class for all errors that are suppressed, and not sent to sentry.""" | ||||||
|  |  | ||||||
|  |  | ||||||
| ignored_classes = ( |  | ||||||
|     # Inbuilt types |  | ||||||
|     KeyboardInterrupt, |  | ||||||
|     ConnectionResetError, |  | ||||||
|     OSError, |  | ||||||
|     PermissionError, |  | ||||||
|     # Django Errors |  | ||||||
|     Error, |  | ||||||
|     ImproperlyConfigured, |  | ||||||
|     DatabaseError, |  | ||||||
|     OperationalError, |  | ||||||
|     InternalError, |  | ||||||
|     ProgrammingError, |  | ||||||
|     SuspiciousOperation, |  | ||||||
|     ValidationError, |  | ||||||
|     # Redis errors |  | ||||||
|     RedisConnectionError, |  | ||||||
|     ConnectionInterrupted, |  | ||||||
|     RedisError, |  | ||||||
|     ResponseError, |  | ||||||
|     # websocket errors |  | ||||||
|     ChannelFull, |  | ||||||
|     WebSocketException, |  | ||||||
|     LocalProtocolError, |  | ||||||
|     # rest_framework error |  | ||||||
|     APIException, |  | ||||||
|     # celery errors |  | ||||||
|     WorkerLostError, |  | ||||||
|     CeleryError, |  | ||||||
|     SoftTimeLimitExceeded, |  | ||||||
|     # custom baseclass |  | ||||||
|     SentryIgnoredException, |  | ||||||
|     # ldap errors |  | ||||||
|     LDAPException, |  | ||||||
|     # Docker errors |  | ||||||
|     DockerException, |  | ||||||
|     # End-user errors |  | ||||||
|     Http404, |  | ||||||
|     # AsyncIO |  | ||||||
|     CancelledError, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SentryTransport(HttpTransport): | class SentryTransport(HttpTransport): | ||||||
|     """Custom sentry transport with custom user-agent""" |     """Custom sentry transport with custom user-agent""" | ||||||
|  |  | ||||||
| @ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float: | |||||||
|     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) |     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def should_ignore_exception(exc: Exception) -> bool: |  | ||||||
|     """Check if an exception should be dropped""" |  | ||||||
|     return isinstance(exc, ignored_classes) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def before_send(event: dict, hint: dict) -> dict | None: | def before_send(event: dict, hint: dict) -> dict | None: | ||||||
|     """Check if error is database error, and ignore if so""" |     """Check if error is database error, and ignore if so""" | ||||||
|  |  | ||||||
|  |     from psycopg.errors import Error | ||||||
|  |  | ||||||
|  |     ignored_classes = ( | ||||||
|  |         # Inbuilt types | ||||||
|  |         KeyboardInterrupt, | ||||||
|  |         ConnectionResetError, | ||||||
|  |         OSError, | ||||||
|  |         PermissionError, | ||||||
|  |         # Django Errors | ||||||
|  |         Error, | ||||||
|  |         ImproperlyConfigured, | ||||||
|  |         DatabaseError, | ||||||
|  |         OperationalError, | ||||||
|  |         InternalError, | ||||||
|  |         ProgrammingError, | ||||||
|  |         SuspiciousOperation, | ||||||
|  |         ValidationError, | ||||||
|  |         # Redis errors | ||||||
|  |         RedisConnectionError, | ||||||
|  |         ConnectionInterrupted, | ||||||
|  |         RedisError, | ||||||
|  |         ResponseError, | ||||||
|  |         # websocket errors | ||||||
|  |         ChannelFull, | ||||||
|  |         WebSocketException, | ||||||
|  |         LocalProtocolError, | ||||||
|  |         # rest_framework error | ||||||
|  |         APIException, | ||||||
|  |         # celery errors | ||||||
|  |         WorkerLostError, | ||||||
|  |         CeleryError, | ||||||
|  |         SoftTimeLimitExceeded, | ||||||
|  |         # custom baseclass | ||||||
|  |         SentryIgnoredException, | ||||||
|  |         # ldap errors | ||||||
|  |         LDAPException, | ||||||
|  |         # Docker errors | ||||||
|  |         DockerException, | ||||||
|  |         # End-user errors | ||||||
|  |         Http404, | ||||||
|  |         # AsyncIO | ||||||
|  |         CancelledError, | ||||||
|  |     ) | ||||||
|     exc_value = None |     exc_value = None | ||||||
|     if "exc_info" in hint: |     if "exc_info" in hint: | ||||||
|         _, exc_value, _ = hint["exc_info"] |         _, exc_value, _ = hint["exc_info"] | ||||||
|         if should_ignore_exception(exc_value): |         if isinstance(exc_value, ignored_classes): | ||||||
|             LOGGER.debug("dropping exception", exc=exc_value) |             LOGGER.debug("dropping exception", exc=exc_value) | ||||||
|             return None |             return None | ||||||
|     if "logger" in event: |     if "logger" in event: | ||||||
|  | |||||||
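before_send and traces_sampler in this file are callbacks of the kind passed to sentry_sdk.init. A minimal sketch of that wiring, assuming the sentry_sdk package is installed; the empty DSN keeps the SDK inert, and authentik's real initialization lives elsewhere with more options.

    import sentry_sdk

    def before_send(event: dict, hint: dict) -> dict | None:
        # simplified filter: drop KeyboardInterrupt, forward everything else
        _, exc_value, _ = hint.get("exc_info", (None, None, None))
        return None if isinstance(exc_value, KeyboardInterrupt) else event

    # dsn="" disables sending, so this sketch is safe to run locally
    sentry_sdk.init(dsn="", before_send=before_send, traces_sample_rate=0.1)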
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | from authentik.lib.sentry import SentryIgnoredException, before_send | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSentry(TestCase): | class TestSentry(TestCase): | ||||||
| @ -10,8 +10,8 @@ class TestSentry(TestCase): | |||||||
|  |  | ||||||
|     def test_error_not_sent(self): |     def test_error_not_sent(self): | ||||||
|         """Test SentryIgnoredError not sent""" |         """Test SentryIgnoredError not sent""" | ||||||
|         self.assertTrue(should_ignore_exception(SentryIgnoredException())) |         self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) | ||||||
|  |  | ||||||
|     def test_error_sent(self): |     def test_error_sent(self): | ||||||
|         """Test error sent""" |         """Test error sent""" | ||||||
|         self.assertFalse(should_ignore_exception(ValueError())) |         self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)})) | ||||||
|  | |||||||
| @ -1,13 +1,15 @@ | |||||||
| """authentik outpost signals""" | """authentik outpost signals""" | ||||||
|  |  | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import AuthenticatedSession, Provider | from authentik.core.models import AuthenticatedSession, Provider, User | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||||
| @ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): | |||||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) |     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||||
|  |     """Catch logout by direct logout and forward to providers""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     outpost_session_end.delay(request.session.session_key) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|     """Catch logout by expiring sessions being deleted""" |     """Catch logout by expiring sessions being deleted""" | ||||||
|  | |||||||
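The logout_revoke_direct receiver above keys everything off request.session.session_key, and the guard matters because a request may carry no session key. A minimal, framework-light sketch of the same guard pattern; end_session is a hypothetical callable standing in for the outpost_session_end task.

    from collections.abc import Callable

    def on_logout(session_key: str | None, end_session: Callable[[str], None]) -> bool:
        """Mirror the guard in the receiver: only fan out when a session key exists."""
        if not session_key:
            return False
        end_session(session_key)  # in the diff this is outpost_session_end.delay(...)
        return True

    ended: list[str] = []
    assert on_logout("abc123", ended.append) is True
    assert on_logout(None, ended.append) is False
    assert ended == ["abc123"]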
| @ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException): | |||||||
|  |  | ||||||
|     error: str |     error: str | ||||||
|     description: str |     description: str | ||||||
|  |     cause: str | None = None | ||||||
|  |  | ||||||
|     def create_dict(self): |     def create_dict(self): | ||||||
|         """Return error as dict for JSON Rendering""" |         """Return error as dict for JSON Rendering""" | ||||||
| @ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException): | |||||||
|             **kwargs, |             **kwargs, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def with_cause(self, cause: str): | ||||||
|  |         self.cause = cause | ||||||
|  |         return self | ||||||
|  |  | ||||||
|  |  | ||||||
| class RedirectUriError(OAuth2Error): | class RedirectUriError(OAuth2Error): | ||||||
|     """The request fails due to a missing, invalid, or mismatching |     """The request fails due to a missing, invalid, or mismatching | ||||||
|  | |||||||
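with_cause above returns self, which is what allows the raise SomeError(...).with_cause("...") chaining seen later in this diff and lets tests read cm.exception.cause. A trimmed-down sketch of the idiom; SketchError is illustrative, not the real OAuth2Error.

    class SketchError(Exception):  # trimmed-down stand-in for OAuth2Error
        cause: str | None = None

        def with_cause(self, cause: str) -> "SketchError":
            self.cause = cause
            return self  # returning self is what makes `raise ....with_cause(...)` work

    try:
        raise SketchError().with_cause("redirect_uri_missing")
    except SketchError as exc:
        assert exc.cause == "redirect_uri_missing"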
| @ -1,10 +1,23 @@ | |||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.db.models.signals import post_save, pre_delete | from django.db.models.signals import post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Revoke tokens upon user logout""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     AccessToken.objects.filter( | ||||||
|  |         user=user, | ||||||
|  |         session__session__session_key=request.session.session_key, | ||||||
|  |     ).delete() | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | ||||||
|     """Revoke tokens upon user logout""" |     """Revoke tokens upon user logout""" | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow | |||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.providers.oauth2.constants import TOKEN_TYPE | from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE | ||||||
| from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError | from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError | ||||||
| from authentik.providers.oauth2.models import ( | from authentik.providers.oauth2.models import ( | ||||||
|     AccessToken, |     AccessToken, | ||||||
| @ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.error, "unsupported_response_type") | ||||||
|  |  | ||||||
|     def test_invalid_client_id(self): |     def test_invalid_client_id(self): | ||||||
|         """Test invalid client ID""" |         """Test invalid client ID""" | ||||||
| @ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.error, "request_not_supported") | ||||||
|  |  | ||||||
|     def test_invalid_redirect_uri(self): |     def test_invalid_redirect_uri_missing(self): | ||||||
|         """test missing/invalid redirect URI""" |         """test missing redirect URI""" | ||||||
|         OAuth2Provider.objects.create( |         OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError) as cm: | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|         with self.assertRaises(RedirectUriError): |         self.assertEqual(cm.exception.cause, "redirect_uri_missing") | ||||||
|  |  | ||||||
|  |     def test_invalid_redirect_uri(self): | ||||||
|  |         """test invalid redirect URI""" | ||||||
|  |         OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             client_id="test", | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||||
|  |         ) | ||||||
|  |         with self.assertRaises(RedirectUriError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||||
|  |  | ||||||
|     def test_blocked_redirect_uri(self): |     def test_blocked_redirect_uri(self): | ||||||
|         """test missing/invalid redirect URI""" |         """test missing/invalid redirect URI""" | ||||||
| @ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme") | ||||||
|  |  | ||||||
|     def test_invalid_redirect_uri_empty(self): |     def test_invalid_redirect_uri_empty(self): | ||||||
|         """test missing/invalid redirect URI""" |         """test missing/invalid redirect URI""" | ||||||
| @ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[], |             redirect_uris=[], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |  | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |  | ||||||
|             OAuthAuthorizationParams.from_request(request) |  | ||||||
|         request = self.factory.get( |         request = self.factory.get( | ||||||
|             "/", |             "/", | ||||||
|             data={ |             data={ | ||||||
| @ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError) as cm: | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |  | ||||||
|             OAuthAuthorizationParams.from_request(request) |  | ||||||
|         with self.assertRaises(RedirectUriError): |  | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||||
|  |  | ||||||
|     def test_redirect_uri_invalid_regex(self): |     def test_redirect_uri_invalid_regex(self): | ||||||
|         """test missing/invalid redirect URI (invalid regex)""" |         """test missing/invalid redirect URI (invalid regex)""" | ||||||
| @ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")], |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |         with self.assertRaises(RedirectUriError) as cm: | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |  | ||||||
|             OAuthAuthorizationParams.from_request(request) |  | ||||||
|         with self.assertRaises(RedirectUriError): |  | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||||
|  |  | ||||||
|     def test_empty_redirect_uri(self): |     def test_redirect_uri_regex(self): | ||||||
|         """test empty redirect URI (configure in provider)""" |         """test valid redirect URI (regex)""" | ||||||
|         OAuth2Provider.objects.create( |         OAuth2Provider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")], | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(RedirectUriError): |  | ||||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) |  | ||||||
|             OAuthAuthorizationParams.from_request(request) |  | ||||||
|         request = self.factory.get( |         request = self.factory.get( | ||||||
|             "/", |             "/", | ||||||
|             data={ |             data={ | ||||||
|                 "response_type": "code", |                 "response_type": "code", | ||||||
|                 "client_id": "test", |                 "client_id": "test", | ||||||
|                 "redirect_uri": "http://localhost", |                 "redirect_uri": "http://foo.bar.baz", | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         OAuthAuthorizationParams.from_request(request) |         OAuthAuthorizationParams.from_request(request) | ||||||
| @ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             GrantTypes.IMPLICIT, |             GrantTypes.IMPLICIT, | ||||||
|         ) |         ) | ||||||
|         # Implicit without openid scope |         # Implicit without openid scope | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID |             OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID | ||||||
|         ) |         ) | ||||||
|         with self.assertRaises(AuthorizeError): |         with self.assertRaises(AuthorizeError) as cm: | ||||||
|             request = self.factory.get( |             request = self.factory.get( | ||||||
|                 "/", |                 "/", | ||||||
|                 data={ |                 data={ | ||||||
| @ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|             OAuthAuthorizationParams.from_request(request) |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.error, "unsupported_response_type") | ||||||
|  |  | ||||||
|     def test_full_code(self): |     def test_full_code(self): | ||||||
|         """Test full authorization""" |         """Test full authorization""" | ||||||
| @ -387,7 +393,8 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 response.url, |                 response.url, | ||||||
|                 ( |                 ( | ||||||
|                     f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}" |                     f"http://localhost#access_token={token.token}" | ||||||
|  |                     f"&id_token={provider.encode(token.id_token.to_dict())}" | ||||||
|                     f"&token_type={TOKEN_TYPE}" |                     f"&token_type={TOKEN_TYPE}" | ||||||
|                     f"&expires_in={int(expires)}&state={state}" |                     f"&expires_in={int(expires)}&state={state}" | ||||||
|                 ), |                 ), | ||||||
| @ -562,6 +569,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 "url": "http://localhost", |                 "url": "http://localhost", | ||||||
|                 "title": f"Redirecting to {app.name}...", |                 "title": f"Redirecting to {app.name}...", | ||||||
|                 "attrs": { |                 "attrs": { | ||||||
|  |                     "access_token": token.token, | ||||||
|                     "id_token": provider.encode(token.id_token.to_dict()), |                     "id_token": provider.encode(token.id_token.to_dict()), | ||||||
|                     "token_type": TOKEN_TYPE, |                     "token_type": TOKEN_TYPE, | ||||||
|                     "expires_in": "3600", |                     "expires_in": "3600", | ||||||
| @ -613,3 +621,54 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 }, |                 }, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_openid_missing_invalid(self): | ||||||
|  |         """test request requiring an OpenID scope to be set""" | ||||||
|  |         OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             client_id="test", | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|  |         ) | ||||||
|  |         request = self.factory.get( | ||||||
|  |             "/", | ||||||
|  |             data={ | ||||||
|  |                 "response_type": "id_token", | ||||||
|  |                 "client_id": "test", | ||||||
|  |                 "redirect_uri": "http://localhost", | ||||||
|  |                 "scope": "", | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         with self.assertRaises(AuthorizeError) as cm: | ||||||
|  |             OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertEqual(cm.exception.cause, "scope_openid_missing") | ||||||
|  |  | ||||||
|  |     @apply_blueprint("system/providers-oauth2.yaml") | ||||||
|  |     def test_offline_access_invalid(self): | ||||||
|  |         """test request for offline_access with invalid response type""" | ||||||
|  |         provider = OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             client_id="test", | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||||
|  |         ) | ||||||
|  |         provider.property_mappings.set( | ||||||
|  |             ScopeMapping.objects.filter( | ||||||
|  |                 managed__in=[ | ||||||
|  |                     "goauthentik.io/providers/oauth2/scope-openid", | ||||||
|  |                     "goauthentik.io/providers/oauth2/scope-offline_access", | ||||||
|  |                 ] | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         request = self.factory.get( | ||||||
|  |             "/", | ||||||
|  |             data={ | ||||||
|  |                 "response_type": "id_token", | ||||||
|  |                 "client_id": "test", | ||||||
|  |                 "redirect_uri": "http://localhost", | ||||||
|  |                 "scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}", | ||||||
|  |                 "nonce": generate_id(), | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         parsed = OAuthAuthorizationParams.from_request(request) | ||||||
|  |         self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope) | ||||||
|  | |||||||
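The tests in the hunk above assert that fragment-style responses place the access token, ID token, token type, expiry and state in the URL fragment. As a point of reference, a relying party would parse such a redirect roughly like the sketch below; the concrete values are illustrative only and not taken from the tests.

```python
# Sketch only: illustrative values, not the fixtures used in the tests above.
from urllib.parse import parse_qs, urlsplit

redirect = "http://localhost#access_token=abc123&id_token=xyz&token_type=Bearer&expires_in=3600&state=s1"
fragment_params = parse_qs(urlsplit(redirect).fragment)

assert fragment_params["access_token"] == ["abc123"]
assert fragment_params["expires_in"] == ["3600"]
```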
| @ -150,12 +150,12 @@ class OAuthAuthorizationParams: | |||||||
|         self.check_redirect_uri() |         self.check_redirect_uri() | ||||||
|         self.check_grant() |         self.check_grant() | ||||||
|         self.check_scope(github_compat) |         self.check_scope(github_compat) | ||||||
|  |         self.check_nonce() | ||||||
|  |         self.check_code_challenge() | ||||||
|         if self.request: |         if self.request: | ||||||
|             raise AuthorizeError( |             raise AuthorizeError( | ||||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state |                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||||
|             ) |             ) | ||||||
|         self.check_nonce() |  | ||||||
|         self.check_code_challenge() |  | ||||||
|  |  | ||||||
|     def check_grant(self): |     def check_grant(self): | ||||||
|         """Check grant""" |         """Check grant""" | ||||||
| @ -190,7 +190,7 @@ class OAuthAuthorizationParams: | |||||||
|         allowed_redirect_urls = self.provider.redirect_uris |         allowed_redirect_urls = self.provider.redirect_uris | ||||||
|         if not self.redirect_uri: |         if not self.redirect_uri: | ||||||
|             LOGGER.warning("Missing redirect uri.") |             LOGGER.warning("Missing redirect uri.") | ||||||
|             raise RedirectUriError("", allowed_redirect_urls) |             raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing") | ||||||
|  |  | ||||||
|         if len(allowed_redirect_urls) < 1: |         if len(allowed_redirect_urls) < 1: | ||||||
|             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) |             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) | ||||||
| @ -219,10 +219,14 @@ class OAuthAuthorizationParams: | |||||||
|                         provider=self.provider, |                         provider=self.provider, | ||||||
|                     ) |                     ) | ||||||
|         if not match_found: |         if not match_found: | ||||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) |             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause( | ||||||
|  |                 "redirect_uri_no_match" | ||||||
|  |             ) | ||||||
|         # Check against forbidden schemes |         # Check against forbidden schemes | ||||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: |         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) |             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause( | ||||||
|  |                 "redirect_uri_forbidden_scheme" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     def check_scope(self, github_compat=False): |     def check_scope(self, github_compat=False): | ||||||
|         """Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" |         """Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" | ||||||
| @ -251,7 +255,9 @@ class OAuthAuthorizationParams: | |||||||
|             or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] |             or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] | ||||||
|         ): |         ): | ||||||
|             LOGGER.warning("Missing 'openid' scope.") |             LOGGER.warning("Missing 'openid' scope.") | ||||||
|             raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state) |             raise AuthorizeError( | ||||||
|  |                 self.redirect_uri, "invalid_scope", self.grant_type, self.state | ||||||
|  |             ).with_cause("scope_openid_missing") | ||||||
|         if SCOPE_OFFLINE_ACCESS in self.scope: |         if SCOPE_OFFLINE_ACCESS in self.scope: | ||||||
|             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess |             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess | ||||||
|             # Don't explicitly request consent with offline_access, as the spec allows for |             # Don't explicitly request consent with offline_access, as the spec allows for | ||||||
| @ -286,7 +292,9 @@ class OAuthAuthorizationParams: | |||||||
|             return |             return | ||||||
|         if not self.nonce: |         if not self.nonce: | ||||||
|             LOGGER.warning("Missing nonce for OpenID Request") |             LOGGER.warning("Missing nonce for OpenID Request") | ||||||
|             raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state) |             raise AuthorizeError( | ||||||
|  |                 self.redirect_uri, "invalid_request", self.grant_type, self.state | ||||||
|  |             ).with_cause("nonce_missing") | ||||||
|  |  | ||||||
|     def check_code_challenge(self): |     def check_code_challenge(self): | ||||||
|         """PKCE validation of the transformation method.""" |         """PKCE validation of the transformation method.""" | ||||||
| @ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView): | |||||||
|                 self.request, github_compat=self.github_compat |                 self.request, github_compat=self.github_compat | ||||||
|             ) |             ) | ||||||
|         except AuthorizeError as error: |         except AuthorizeError as error: | ||||||
|             LOGGER.warning(error.description, redirect_uri=error.redirect_uri) |             LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause) | ||||||
|             raise RequestValidationError(error.get_response(self.request)) from None |             raise RequestValidationError(error.get_response(self.request)) from None | ||||||
|         except OAuth2Error as error: |         except OAuth2Error as error: | ||||||
|             LOGGER.warning(error.description) |             LOGGER.warning(error.description, cause=error.cause) | ||||||
|             raise RequestValidationError( |             raise RequestValidationError( | ||||||
|                 bad_request_message(self.request, error.description, title=error.error) |                 bad_request_message(self.request, error.description, title=error.error) | ||||||
|             ) from None |             ) from None | ||||||
| @ -630,6 +638,7 @@ class OAuthFulfillmentStage(StageView): | |||||||
|         if self.params.response_type in [ |         if self.params.response_type in [ | ||||||
|             ResponseTypes.ID_TOKEN_TOKEN, |             ResponseTypes.ID_TOKEN_TOKEN, | ||||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, |             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||||
|  |             ResponseTypes.ID_TOKEN, | ||||||
|             ResponseTypes.CODE_TOKEN, |             ResponseTypes.CODE_TOKEN, | ||||||
|         ]: |         ]: | ||||||
|             query_fragment["access_token"] = token.token |             query_fragment["access_token"] = token.token | ||||||
|  | |||||||
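The hunk above chains `.with_cause(...)` onto raised errors and later logs `error.cause`; only the call sites appear in this diff. A minimal sketch of how such a chainable helper could look on an error class follows — an assumed shape for illustration, not the actual `OAuth2Error`/`AuthorizeError` implementation.

```python
# Assumed shape of a chainable cause helper; the real error classes are not part of this diff.
class ExampleOAuth2Error(Exception):
    error = "invalid_request"
    description = "generic error"

    def __init__(self) -> None:
        super().__init__(self.description)
        self.cause: str | None = None

    def with_cause(self, cause: str) -> "ExampleOAuth2Error":
        # Record a machine-readable cause (used by tests and structured logging)
        # and return self so the call can be chained directly onto `raise`.
        self.cause = cause
        return self


# Usage mirroring the pattern in the hunk above:
# raise ExampleOAuth2Error().with_cause("scope_openid_missing")
```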
| @ -555,8 +555,6 @@ class TokenView(View): | |||||||
|  |  | ||||||
|     provider: OAuth2Provider | None = None |     provider: OAuth2Provider | None = None | ||||||
|     params: TokenParams | None = None |     params: TokenParams | None = None | ||||||
|     params_class = TokenParams |  | ||||||
|     provider_class = OAuth2Provider |  | ||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: |     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||||
|         response = super().dispatch(request, *args, **kwargs) |         response = super().dispatch(request, *args, **kwargs) | ||||||
| @ -576,14 +574,12 @@ class TokenView(View): | |||||||
|                 op="authentik.providers.oauth2.post.parse", |                 op="authentik.providers.oauth2.post.parse", | ||||||
|             ): |             ): | ||||||
|                 client_id, client_secret = extract_client_auth(request) |                 client_id, client_secret = extract_client_auth(request) | ||||||
|                 self.provider = self.provider_class.objects.filter(client_id=client_id).first() |                 self.provider = OAuth2Provider.objects.filter(client_id=client_id).first() | ||||||
|                 if not self.provider: |                 if not self.provider: | ||||||
|                     LOGGER.warning("OAuth2Provider does not exist", client_id=client_id) |                     LOGGER.warning("OAuth2Provider does not exist", client_id=client_id) | ||||||
|                     raise TokenError("invalid_client") |                     raise TokenError("invalid_client") | ||||||
|                 CTX_AUTH_VIA.set("oauth_client_secret") |                 CTX_AUTH_VIA.set("oauth_client_secret") | ||||||
|                 self.params = self.params_class.parse( |                 self.params = TokenParams.parse(request, self.provider, client_id, client_secret) | ||||||
|                     request, self.provider, client_id, client_secret |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|             with start_span( |             with start_span( | ||||||
|                 op="authentik.providers.oauth2.post.response", |                 op="authentik.providers.oauth2.post.response", | ||||||
|  | |||||||
| @ -66,10 +66,7 @@ class RACClientConsumer(AsyncWebsocketConsumer): | |||||||
|     def init_outpost_connection(self): |     def init_outpost_connection(self): | ||||||
|         """Initialize guac connection settings""" |         """Initialize guac connection settings""" | ||||||
|         self.token = ( |         self.token = ( | ||||||
|             ConnectionToken.filter_not_expired( |             ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"]) | ||||||
|                 token=self.scope["url_route"]["kwargs"]["token"], |  | ||||||
|                 session__session__session_key=self.scope["session"].session_key, |  | ||||||
|             ) |  | ||||||
|             .select_related("endpoint", "provider", "session", "session__user") |             .select_related("endpoint", "provider", "session", "session__user") | ||||||
|             .first() |             .first() | ||||||
|         ) |         ) | ||||||
|  | |||||||
| @ -2,11 +2,13 @@ | |||||||
|  |  | ||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models.signals import post_delete, post_save, pre_delete | from django.db.models.signals import post_delete, post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | ||||||
| from authentik.providers.rac.consumer_client import ( | from authentik.providers.rac.consumer_client import ( | ||||||
|     RAC_CLIENT_GROUP_SESSION, |     RAC_CLIENT_GROUP_SESSION, | ||||||
| @ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import ( | |||||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint | from authentik.providers.rac.models import ConnectionToken, Endpoint | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def user_logged_out_session(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Disconnect any open RAC connections""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     layer = get_channel_layer() | ||||||
|  |     async_to_sync(layer.group_send)( | ||||||
|  |         RAC_CLIENT_GROUP_SESSION | ||||||
|  |         % { | ||||||
|  |             "session": request.session.session_key, | ||||||
|  |         }, | ||||||
|  |         {"type": "event.disconnect", "reason": "session_logout"}, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def user_session_deleted(sender, instance: AuthenticatedSession, **_): | def user_session_deleted(sender, instance: AuthenticatedSession, **_): | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|  | |||||||
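The new `user_logged_out` receiver sends a `{"type": "event.disconnect"}` message to the session's channel group. On the consumer side, Channels dispatches a group message of that type to a handler method named `event_disconnect`; the sketch below shows that receiving side in minimal form — the actual `RACClientConsumer` handler is not part of this hunk.

```python
# Minimal sketch of the receiving side; the real RAC consumer is not shown in this diff.
from channels.generic.websocket import AsyncWebsocketConsumer


class ExampleRACConsumer(AsyncWebsocketConsumer):
    async def event_disconnect(self, event: dict):
        # Triggered by group_send({"type": "event.disconnect", ...}) above:
        # close the WebSocket so the remote-access session ends with the login session.
        await self.close()
```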
| @ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -87,22 +87,3 @@ class TestRACViews(APITestCase): | |||||||
|         ) |         ) | ||||||
|         body = loads(flow_response.content) |         body = loads(flow_response.content) | ||||||
|         self.assertEqual(body["component"], "ak-stage-access-denied") |         self.assertEqual(body["component"], "ak-stage-access-denied") | ||||||
|  |  | ||||||
|     def test_different_session(self): |  | ||||||
|         """Test request""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_providers_rac:start", |  | ||||||
|                 kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)}, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 302) |  | ||||||
|         flow_response = self.client.get( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) |  | ||||||
|         ) |  | ||||||
|         body = loads(flow_response.content) |  | ||||||
|         next_url = body["to"] |  | ||||||
|         self.client.logout() |  | ||||||
|         final_response = self.client.get(next_url) |  | ||||||
|         self.assertEqual(final_response.url, reverse("authentik_core:if-user")) |  | ||||||
|  | |||||||
| @ -68,10 +68,7 @@ class RACInterface(InterfaceView): | |||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: |     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||||
|         # Early sanity check to ensure token still exists |         # Early sanity check to ensure token still exists | ||||||
|         token = ConnectionToken.filter_not_expired( |         token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first() | ||||||
|             token=self.kwargs["token"], |  | ||||||
|             session__session__session_key=request.session.session_key, |  | ||||||
|         ).first() |  | ||||||
|         if not token: |         if not token: | ||||||
|             return redirect("authentik_core:if-user") |             return redirect("authentik_core:if-user") | ||||||
|         self.token = token |         self.token = token | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ from itertools import batched | |||||||
| from django.db import transaction | from django.db import transaction | ||||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||||
| from pydanticscim.group import GroupMember | from pydanticscim.group import GroupMember | ||||||
|  | from pydanticscim.responses import PatchOp | ||||||
|  |  | ||||||
| from authentik.core.models import Group | from authentik.core.models import Group | ||||||
| from authentik.lib.sync.mapper import PropertyMappingManager | from authentik.lib.sync.mapper import PropertyMappingManager | ||||||
| @ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | |||||||
| from authentik.providers.scim.clients.exceptions import ( | from authentik.providers.scim.clients.exceptions import ( | ||||||
|     SCIMRequestException, |     SCIMRequestException, | ||||||
| ) | ) | ||||||
| from authentik.providers.scim.clients.schema import ( | from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||||
|     SCIM_GROUP_SCHEMA, |  | ||||||
|     PatchOp, |  | ||||||
|     PatchOperation, |  | ||||||
|     PatchRequest, |  | ||||||
| ) |  | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||||
| from authentik.providers.scim.models import ( | from authentik.providers.scim.models import ( | ||||||
|     SCIMMapping, |     SCIMMapping, | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """Custom SCIM schemas""" | """Custom SCIM schemas""" | ||||||
|  |  | ||||||
| from enum import Enum |  | ||||||
|  |  | ||||||
| from pydantic import Field | from pydantic import Field | ||||||
| from pydanticscim.group import Group as BaseGroup | from pydanticscim.group import Group as BaseGroup | ||||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||||
| @ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PatchOp(str, Enum): |  | ||||||
|  |  | ||||||
|     replace = "replace" |  | ||||||
|     remove = "remove" |  | ||||||
|     add = "add" |  | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def _missing_(cls, value): |  | ||||||
|         value = value.lower() |  | ||||||
|         for member in cls: |  | ||||||
|             if member.lower() == value: |  | ||||||
|                 return member |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PatchRequest(BasePatchRequest): | class PatchRequest(BasePatchRequest): | ||||||
|     """PatchRequest which correctly sets schemas""" |     """PatchRequest which correctly sets schemas""" | ||||||
|  |  | ||||||
| @ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest): | |||||||
| class PatchOperation(BasePatchOperation): | class PatchOperation(BasePatchOperation): | ||||||
|     """PatchOperation with optional path""" |     """PatchOperation with optional path""" | ||||||
|  |  | ||||||
|     op: PatchOp |  | ||||||
|     path: str | None |     path: str | None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -38,7 +38,6 @@ class TestAPIPerms(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -74,7 +73,6 @@ class TestAPIPerms(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | |||||||
|  |  | ||||||
| import django | import django | ||||||
| from channels.routing import ProtocolTypeRouter, URLRouter | from channels.routing import ProtocolTypeRouter, URLRouter | ||||||
|  | from defusedxml import defuse_stdlib | ||||||
| from django.core.asgi import get_asgi_application | from django.core.asgi import get_asgi_application | ||||||
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | ||||||
|  |  | ||||||
| from authentik.root.setup import setup |  | ||||||
|  |  | ||||||
| # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | ||||||
|  |  | ||||||
| setup() | defuse_stdlib() | ||||||
| django.setup() | django.setup() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -27,7 +27,7 @@ from structlog.stdlib import get_logger | |||||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.lib.sentry import should_ignore_exception | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
| # set the default Django settings module for the 'celery' program. | # set the default Django settings module for the 'celery' program. | ||||||
| @ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | |||||||
|  |  | ||||||
|     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) |     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) | ||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     if not should_ignore_exception(exception): |     if before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||||
|         Event.new( |         Event.new( | ||||||
|             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id |             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id | ||||||
|         ).save() |         ).save() | ||||||
|  | |||||||
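The two sides of this hunk swap `should_ignore_exception(exception)` for a direct `before_send(...)` check. Reading the condition inversion alone, the helper is equivalent to asking whether Sentry's `before_send` filter would drop the event; a hedged sketch of that equivalence is below (the real helper in `authentik.lib.sentry` may differ).

```python
# Inferred purely from the swapped condition in this hunk; assumes before_send is
# importable from authentik.lib.sentry as shown on the right-hand side.
def should_ignore_exception(exc: Exception) -> bool:
    """True if Sentry's before_send filter would drop the event for this exception."""
    return before_send({}, {"exc_info": (None, exc, None)}) is None
```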
| @ -1,49 +1,13 @@ | |||||||
| """authentik database backend""" | """authentik database backend""" | ||||||
|  |  | ||||||
| from django.core.checks import Warning |  | ||||||
| from django.db.backends.base.validation import BaseDatabaseValidation |  | ||||||
| from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseValidation(BaseDatabaseValidation): |  | ||||||
|  |  | ||||||
|     def check(self, **kwargs): |  | ||||||
|         return self._check_encoding() |  | ||||||
|  |  | ||||||
|     def _check_encoding(self): |  | ||||||
|         """Throw a warning when the server_encoding is not UTF-8 or |  | ||||||
|         server_encoding and client_encoding are mismatched""" |  | ||||||
|         messages = [] |  | ||||||
|         with self.connection.cursor() as cursor: |  | ||||||
|             cursor.execute("SHOW server_encoding;") |  | ||||||
|             server_encoding = cursor.fetchone()[0] |  | ||||||
|             cursor.execute("SHOW client_encoding;") |  | ||||||
|             client_encoding = cursor.fetchone()[0] |  | ||||||
|             if server_encoding != client_encoding: |  | ||||||
|                 messages.append( |  | ||||||
|                     Warning( |  | ||||||
|                         "PostgreSQL Server and Client encoding are mismatched: Server: " |  | ||||||
|                         f"{server_encoding}, Client: {client_encoding}", |  | ||||||
|                         id="ak.db.W001", |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             if server_encoding != "UTF8": |  | ||||||
|                 messages.append( |  | ||||||
|                     Warning( |  | ||||||
|                         f"PostgreSQL Server encoding is not UTF8: {server_encoding}", |  | ||||||
|                         id="ak.db.W002", |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|         return messages |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseWrapper(BaseDatabaseWrapper): | class DatabaseWrapper(BaseDatabaseWrapper): | ||||||
|     """database backend which supports rotating credentials""" |     """database backend which supports rotating credentials""" | ||||||
|  |  | ||||||
|     validation_class = DatabaseValidation |  | ||||||
|  |  | ||||||
|     def get_connection_params(self): |     def get_connection_params(self): | ||||||
|         """Refresh DB credentials before getting connection params""" |         """Refresh DB credentials before getting connection params""" | ||||||
|         conn_params = super().get_connection_params() |         conn_params = super().get_connection_params() | ||||||
|  | |||||||
| @ -61,22 +61,6 @@ class SessionMiddleware(UpstreamSessionMiddleware): | |||||||
|             pass |             pass | ||||||
|         return session_key |         return session_key | ||||||
|  |  | ||||||
|     @staticmethod |  | ||||||
|     def encode_session(session_key: str, user: User): |  | ||||||
|         payload = { |  | ||||||
|             "sid": session_key, |  | ||||||
|             "iss": "authentik", |  | ||||||
|             "sub": "anonymous", |  | ||||||
|             "authenticated": user.is_authenticated, |  | ||||||
|             "acr": ACR_AUTHENTIK_SESSION, |  | ||||||
|         } |  | ||||||
|         if user.is_authenticated: |  | ||||||
|             payload["sub"] = user.uid |  | ||||||
|         value = encode(payload=payload, key=SIGNING_HASH) |  | ||||||
|         if settings.TEST: |  | ||||||
|             value = session_key |  | ||||||
|         return value |  | ||||||
|  |  | ||||||
|     def process_request(self, request: HttpRequest): |     def process_request(self, request: HttpRequest): | ||||||
|         raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) |         raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) | ||||||
|         session_key = SessionMiddleware.decode_session_key(raw_session) |         session_key = SessionMiddleware.decode_session_key(raw_session) | ||||||
| @ -133,9 +117,21 @@ class SessionMiddleware(UpstreamSessionMiddleware): | |||||||
|                             "request completed. The user may have logged " |                             "request completed. The user may have logged " | ||||||
|                             "out in a concurrent request, for example." |                             "out in a concurrent request, for example." | ||||||
|                         ) from None |                         ) from None | ||||||
|  |                     payload = { | ||||||
|  |                         "sid": request.session.session_key, | ||||||
|  |                         "iss": "authentik", | ||||||
|  |                         "sub": "anonymous", | ||||||
|  |                         "authenticated": request.user.is_authenticated, | ||||||
|  |                         "acr": ACR_AUTHENTIK_SESSION, | ||||||
|  |                     } | ||||||
|  |                     if request.user.is_authenticated: | ||||||
|  |                         payload["sub"] = request.user.uid | ||||||
|  |                     value = encode(payload=payload, key=SIGNING_HASH) | ||||||
|  |                     if settings.TEST: | ||||||
|  |                         value = request.session.session_key | ||||||
|                     response.set_cookie( |                     response.set_cookie( | ||||||
|                         settings.SESSION_COOKIE_NAME, |                         settings.SESSION_COOKIE_NAME, | ||||||
|                         SessionMiddleware.encode_session(request.session.session_key, request.user), |                         value, | ||||||
|                         max_age=max_age, |                         max_age=max_age, | ||||||
|                         expires=expires, |                         expires=expires, | ||||||
|                         domain=settings.SESSION_COOKIE_DOMAIN, |                         domain=settings.SESSION_COOKIE_DOMAIN, | ||||||
|  | |||||||
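Both sides of this middleware hunk build the session cookie as a signed token whose claims (`sid`, `iss`, `sub`, `authenticated`, `acr`) are shown above, signed via `encode(payload=payload, key=SIGNING_HASH)`. The existing `decode_session_key` (not shown here) performs the reverse; a rough sketch of that direction follows, under the assumption of a PyJWT-style `decode` and an HMAC algorithm — neither is confirmed by this hunk.

```python
# Assumptions: PyJWT-style decode() and HS256; only the claim names come from the hunk above.
from jwt import PyJWTError, decode


def session_key_from_cookie(raw_cookie: str, signing_key: str) -> str | None:
    """Return the "sid" claim if the cookie verifies, otherwise None."""
    try:
        payload = decode(raw_cookie, key=signing_key, algorithms=["HS256"])
    except PyJWTError:
        return None
    return payload.get("sid")
```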
| @ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [ | |||||||
|     "MIDDLEWARE", |     "MIDDLEWARE", | ||||||
|     "AUTHENTICATION_BACKENDS", |     "AUTHENTICATION_BACKENDS", | ||||||
|     "CELERY", |     "CELERY", | ||||||
|     "SPECTACULAR_SETTINGS", |  | ||||||
|     "REST_FRAMEWORK", |  | ||||||
| ] | ] | ||||||
|  |  | ||||||
| SILENCED_SYSTEM_CHECKS = [ | SILENCED_SYSTEM_CHECKS = [ | ||||||
| @ -470,8 +468,6 @@ def _update_settings(app_path: str): | |||||||
|         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) |         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) | ||||||
|         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) |         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) | ||||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) |         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) |  | ||||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) |  | ||||||
|         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) |         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) | ||||||
|         for _attr in dir(settings_module): |         for _attr in dir(settings_module): | ||||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: |             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||||
|  | |||||||
| @ -1,26 +0,0 @@ | |||||||
| import os |  | ||||||
| import warnings |  | ||||||
|  |  | ||||||
| from cryptography.hazmat.backends.openssl.backend import backend |  | ||||||
| from defusedxml import defuse_stdlib |  | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def setup(): |  | ||||||
|     warnings.filterwarnings("ignore", "SelectableGroups dict interface") |  | ||||||
|     warnings.filterwarnings( |  | ||||||
|         "ignore", |  | ||||||
|         "defusedxml.lxml is no longer supported and will be removed in a future release.", |  | ||||||
|     ) |  | ||||||
|     warnings.filterwarnings( |  | ||||||
|         "ignore", |  | ||||||
|         "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     defuse_stdlib() |  | ||||||
|  |  | ||||||
|     if CONFIG.get_bool("compliance.fips.enabled", False): |  | ||||||
|         backend._enable_fips() |  | ||||||
|  |  | ||||||
|     os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") |  | ||||||
| @ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType | |||||||
| from django.test.runner import DiscoverRunner | from django.test.runner import DiscoverRunner | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR |  | ||||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.sentry import sentry_init | from authentik.lib.sentry import sentry_init | ||||||
| from authentik.root.signals import post_startup, pre_startup, startup | from authentik.root.signals import post_startup, pre_startup, startup | ||||||
| @ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover | |||||||
|         for key, value in test_config.items(): |         for key, value in test_config.items(): | ||||||
|             CONFIG.set(key, value) |             CONFIG.set(key, value) | ||||||
|  |  | ||||||
|         ASN_CONTEXT_PROCESSOR.load() |  | ||||||
|         GEOIP_CONTEXT_PROCESSOR.load() |  | ||||||
|  |  | ||||||
|         sentry_init() |         sentry_init() | ||||||
|         self.logger.debug("Test environment configured") |         self.logger.debug("Test environment configured") | ||||||
|  |  | ||||||
|  | |||||||
| @ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str): | |||||||
|             return |             return | ||||||
|         # Delete all sync tasks from the cache |         # Delete all sync tasks from the cache | ||||||
|         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() |         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() | ||||||
|  |         task = chain( | ||||||
|         # The order of these operations needs to be preserved as each depends on the previous one(s) |             # User and group sync can happen at once, they have no dependencies on each other | ||||||
|         # 1. User and group sync can happen simultaneously |             group( | ||||||
|         # 2. Membership sync needs to run afterwards |                 ldap_sync_paginator(source, UserLDAPSynchronizer) | ||||||
|         # 3. Finally, user and group deletions can happen simultaneously |                 + ldap_sync_paginator(source, GroupLDAPSynchronizer), | ||||||
|         user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( |             ), | ||||||
|             source, GroupLDAPSynchronizer |             # Membership sync needs to run afterwards | ||||||
|  |             group( | ||||||
|  |                 ldap_sync_paginator(source, MembershipLDAPSynchronizer), | ||||||
|  |             ), | ||||||
|  |             # Finally, deletions. What we'd really like to do here is something like | ||||||
|  |             # ``` | ||||||
|  |             # user_identifiers = <ldap query> | ||||||
|  |             # User.objects.exclude( | ||||||
|  |             #     usersourceconnection__identifier__in=user_uniqueness_identifiers, | ||||||
|  |             # ).delete() | ||||||
|  |             # ``` | ||||||
|  |             # This runs into performance issues in large installations. So instead we spread the | ||||||
|  |             # work out into three steps: | ||||||
|  |             # 1. Get every object from the LDAP source. | ||||||
|  |             # 2. Mark every object as "safe" in the database. This is quick, but any error could | ||||||
|  |             #    mean deleting users which should not be deleted, so we do it immediately, in | ||||||
|  |             #    large chunks, and only queue the deletion step afterwards. | ||||||
|  |             # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in | ||||||
|  |             #    small chunks. | ||||||
|  |             group( | ||||||
|  |                 ldap_sync_paginator(source, UserLDAPForwardDeletion) | ||||||
|  |                 + ldap_sync_paginator(source, GroupLDAPForwardDeletion), | ||||||
|  |             ), | ||||||
|         ) |         ) | ||||||
|         membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) |         task() | ||||||
|         user_group_deletion = ldap_sync_paginator( |  | ||||||
|             source, UserLDAPForwardDeletion |  | ||||||
|         ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion) |  | ||||||
|  |  | ||||||
|         # Celery is buggy with empty groups, so we are careful only to add non-empty groups. |  | ||||||
|         # See https://github.com/celery/celery/issues/9772 |  | ||||||
|         task_groups = [] |  | ||||||
|         if user_group_sync: |  | ||||||
|             task_groups.append(group(user_group_sync)) |  | ||||||
|         if membership_sync: |  | ||||||
|             task_groups.append(group(membership_sync)) |  | ||||||
|         if user_group_deletion: |  | ||||||
|             task_groups.append(group(user_group_deletion)) |  | ||||||
|  |  | ||||||
|         all_tasks = chain(task_groups) |  | ||||||
|         all_tasks() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | ||||||
|  | |||||||
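Both versions of `ldap_sync_single` arrange the work as a Celery chain of groups: tasks inside a group may run concurrently, while the chain forces each stage to wait for the previous one (users/groups, then memberships, then deletions). A minimal illustration of that structure with hypothetical task names:

```python
# Hypothetical task names; only the chain-of-groups structure mirrors the hunk above.
from celery import chain, group

workflow = chain(
    # Stage 1: users and groups have no dependency on each other.
    group(sync_users.s(), sync_groups.s()),
    # Stage 2: memberships need both to exist already.
    group(sync_memberships.s()),
    # Stage 3: stale-object deletion runs last.
    group(delete_stale_users.s(), delete_stale_groups.s()),
)
workflow()
```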
| @ -1,277 +0,0 @@ | |||||||
| """Test SCIM Group""" |  | ||||||
|  |  | ||||||
| from json import dumps |  | ||||||
| from uuid import uuid4 |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group |  | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema |  | ||||||
| from authentik.sources.scim.models import ( |  | ||||||
|     SCIMSource, |  | ||||||
|     SCIMSourceGroup, |  | ||||||
| ) |  | ||||||
| from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSCIMGroups(APITestCase): |  | ||||||
|     """Test SCIM Group view""" |  | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |  | ||||||
|         self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id()) |  | ||||||
|  |  | ||||||
|     def test_group_list(self): |  | ||||||
|         """Test full group list""" |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_group_list_single(self): |  | ||||||
|         """Test full group list (single group)""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         user = create_test_user() |  | ||||||
|         group.users.add(user) |  | ||||||
|         SCIMSourceGroup.objects.create( |  | ||||||
|             source=self.source, |  | ||||||
|             group=group, |  | ||||||
|             id=str(uuid4()), |  | ||||||
|         ) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "group_id": str(group.pk), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         SCIMGroupSchema.model_validate_json(response.content, strict=True) |  | ||||||
|  |  | ||||||
|     def test_group_create(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_members(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "displayName": generate_id(), |  | ||||||
|                     "externalId": ext_id, |  | ||||||
|                     "members": [{"value": str(user.uuid)}], |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_members_empty(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_duplicate(self): |  | ||||||
|         """Test group create (duplicate)""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)} |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 409) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             response.content, |  | ||||||
|             { |  | ||||||
|                 "detail": "Group with ID exists already.", |  | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], |  | ||||||
|                 "scimType": "uniqueness", |  | ||||||
|                 "status": 409, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_update(self): |  | ||||||
|         """Test group update""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)} |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|  |  | ||||||
|     def test_group_update_non_existent(self): |  | ||||||
|         """Test group update""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "group_id": str(uuid4()), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=404) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             response.content, |  | ||||||
|             { |  | ||||||
|                 "detail": "Group not found.", |  | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], |  | ||||||
|                 "status": 404, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_patch_add(self): |  | ||||||
|         """Test group patch""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.patch( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "Operations": [ |  | ||||||
|                         { |  | ||||||
|                             "op": "Add", |  | ||||||
|                             "path": "members", |  | ||||||
|                             "value": {"value": str(user.uuid)}, |  | ||||||
|                         } |  | ||||||
|                     ] |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         self.assertTrue(group.users.filter(pk=user.pk).exists()) |  | ||||||
|  |  | ||||||
|     def test_group_patch_remove(self): |  | ||||||
|         """Test group patch""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         group.users.add(user) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.patch( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "Operations": [ |  | ||||||
|                         { |  | ||||||
|                             "op": "remove", |  | ||||||
|                             "path": "members", |  | ||||||
|                             "value": {"value": str(user.uuid)}, |  | ||||||
|                         } |  | ||||||
|                     ] |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         self.assertFalse(group.users.filter(pk=user.pk).exists()) |  | ||||||
|  |  | ||||||
|     def test_group_delete(self): |  | ||||||
|         """Test group delete""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.delete( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=204) |  | ||||||
| @ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase): | |||||||
|             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], |             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], | ||||||
|             "0123456789", |             "0123456789", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_user_update(self): |  | ||||||
|         """Test user update""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-users", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "user_id": str(user.uuid), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "id": str(existing.pk), |  | ||||||
|                     "userName": generate_id(), |  | ||||||
|                     "externalId": ext_id, |  | ||||||
|                     "emails": [ |  | ||||||
|                         { |  | ||||||
|                             "primary": True, |  | ||||||
|                             "value": user.email, |  | ||||||
|                         } |  | ||||||
|                     ], |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_user_delete(self): |  | ||||||
|         """Test user delete""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) |  | ||||||
|         response = self.client.delete( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-users", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "user_id": str(user.uuid), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_ | |||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.views import APIView | from rest_framework.views import APIView | ||||||
|  |  | ||||||
| from authentik.core.middleware import CTX_AUTH_VIA |  | ||||||
| from authentik.core.models import Token, TokenIntents, User | from authentik.core.models import Token, TokenIntents, User | ||||||
| from authentik.sources.scim.models import SCIMSource | from authentik.sources.scim.models import SCIMSource | ||||||
|  |  | ||||||
| @ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication): | |||||||
|         _username, _, password = b64decode(key.encode()).decode().partition(":") |         _username, _, password = b64decode(key.encode()).decode().partition(":") | ||||||
|         token = self.check_token(password, source_slug) |         token = self.check_token(password, source_slug) | ||||||
|         if token: |         if token: | ||||||
|             CTX_AUTH_VIA.set("scim_basic") |  | ||||||
|             return (token.user, token) |             return (token.user, token) | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
| @ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication): | |||||||
|         token = self.check_token(key, source_slug) |         token = self.check_token(key, source_slug) | ||||||
|         if not token: |         if not token: | ||||||
|             return None |             return None | ||||||
|         CTX_AUTH_VIA.set("scim_token") |  | ||||||
|         return (token.user, token) |         return (token.user, token) | ||||||
|  | |||||||
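Note on the removed lines above: `CTX_AUTH_VIA` is a context variable used to record which mechanism authenticated the request so it can be attached to log output elsewhere. A minimal sketch of that pattern, using a hypothetical stand-in for the real variable in `authentik.core.middleware` (its exact definition is not part of this diff):

```python
from contextvars import ContextVar

# Hypothetical stand-in for authentik.core.middleware.CTX_AUTH_VIA;
# the real definition is not shown in this diff.
CTX_AUTH_VIA: ContextVar[str | None] = ContextVar("authentik_auth_via", default=None)


def authenticate(header: str) -> str:
    """Toy authenticator that records which mechanism succeeded."""
    if header.startswith("Basic "):
        # Same marker string the removed basic-auth branch used.
        CTX_AUTH_VIA.set("scim_basic")
        return "basic-user"
    # Marker the removed bearer-token branch used.
    CTX_AUTH_VIA.set("scim_token")
    return "token-user"


# A logging middleware could later read the marker back:
authenticate("Bearer abc123")
print(CTX_AUTH_VIA.get())  # -> "scim_token"
```

With this change, neither authentication branch sets the marker for SCIM source requests any longer.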
| @ -1,11 +1,13 @@ | |||||||
| """SCIM Utils""" | """SCIM Utils""" | ||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.paginator import Page, Paginator | from django.core.paginator import Page, Paginator | ||||||
| from django.db.models import Q, QuerySet | from django.db.models import Q, QuerySet | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
|  | from django.urls import resolve | ||||||
| from rest_framework.parsers import JSONParser | from rest_framework.parsers import JSONParser | ||||||
| from rest_framework.permissions import IsAuthenticated | from rest_framework.permissions import IsAuthenticated | ||||||
| from rest_framework.renderers import JSONRenderer | from rest_framework.renderers import JSONRenderer | ||||||
| @ -44,7 +46,7 @@ class SCIMView(APIView): | |||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
|     permission_classes = [IsAuthenticated] |     permission_classes = [IsAuthenticated] | ||||||
|     parser_classes = [SCIMParser, JSONParser] |     parser_classes = [SCIMParser] | ||||||
|     renderer_classes = [SCIMRenderer] |     renderer_classes = [SCIMRenderer] | ||||||
|  |  | ||||||
|     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: |     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: | ||||||
| @ -54,6 +56,28 @@ class SCIMView(APIView): | |||||||
|     def get_authenticators(self): |     def get_authenticators(self): | ||||||
|         return [SCIMTokenAuth(self)] |         return [SCIMTokenAuth(self)] | ||||||
|  |  | ||||||
|  |     def patch_resolve_value(self, raw_value: dict) -> User | Group | None: | ||||||
|  |         """Attempt to resolve a raw `value` attribute of a patch operation into | ||||||
|  |         a database model""" | ||||||
|  |         model = User | ||||||
|  |         query = {} | ||||||
|  |         if "$ref" in raw_value: | ||||||
|  |             url = urlparse(raw_value["$ref"]) | ||||||
|  |             if match := resolve(url.path): | ||||||
|  |                 if match.url_name == "v2-users": | ||||||
|  |                     model = User | ||||||
|  |                     query = {"pk": int(match.kwargs["user_id"])} | ||||||
|  |         elif "type" in raw_value: | ||||||
|  |             match raw_value["type"]: | ||||||
|  |                 case "User": | ||||||
|  |                     model = User | ||||||
|  |                     query = {"pk": int(raw_value["value"])} | ||||||
|  |                 case "Group": | ||||||
|  |                     model = Group | ||||||
|  |         else: | ||||||
|  |             return None | ||||||
|  |         return model.objects.filter(**query).first() | ||||||
|  |  | ||||||
|     def filter_parse(self, request: Request): |     def filter_parse(self, request: Request): | ||||||
|         """Parse the path of a Patch Operation""" |         """Parse the path of a Patch Operation""" | ||||||
|         path = request.query_params.get("filter") |         path = request.query_params.get("filter") | ||||||
|  | |||||||
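The new `patch_resolve_value` helper added above accepts the member `value` entries that appear in SCIM PATCH operations and maps them to a `User` or `Group`. A short sketch of the two input shapes it distinguishes; the IDs and the URL below are invented for illustration and are not real objects:

```python
# 1) A "$ref" entry: the URL path is resolved against the urlconf, and a match
#    on the "v2-users" route looks the User up by the captured user_id.
member_by_ref = {
    "$ref": "https://authentik.example.com/source/scim/my-source/v2/Users/42",
    "value": "42",
}

# 2) A "type"/"value" entry: "User" filters by pk=int(value); "Group" selects
#    the Group model but leaves the pk filter unset in that branch.
member_by_type = {"type": "User", "value": "42"}

# A value with neither "$ref" nor "type" makes the helper return None,
# so callers must handle that case.
```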
Some files were not shown because too many files have changed in this diff.