Compare commits
	
		
			1 Commits
		
	
	
		
			imports-fo
			...
			flows/buff
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b0e6558a4f | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2025.6.2 | current_version = 2025.6.1 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
| @ -21,8 +21,6 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:package.json] | [bumpversion:file:package.json] | ||||||
|  |  | ||||||
| [bumpversion:file:package-lock.json] |  | ||||||
|  |  | ||||||
| [bumpversion:file:docker-compose.yml] | [bumpversion:file:docker-compose.yml] | ||||||
|  |  | ||||||
| [bumpversion:file:schema.yml] | [bumpversion:file:schema.yml] | ||||||
| @ -33,4 +31,6 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:internal/constants/constants.go] | [bumpversion:file:internal/constants/constants.go] | ||||||
|  |  | ||||||
|  | [bumpversion:file:web/src/common/constants.ts] | ||||||
|  |  | ||||||
| [bumpversion:file:lifecycle/aws/template.yaml] | [bumpversion:file:lifecycle/aws/template.yaml] | ||||||
|  | |||||||
| @ -7,9 +7,6 @@ charset = utf-8 | |||||||
| trim_trailing_whitespace = true | trim_trailing_whitespace = true | ||||||
| insert_final_newline = true | insert_final_newline = true | ||||||
|  |  | ||||||
| [*.toml] |  | ||||||
| indent_size = 2 |  | ||||||
|  |  | ||||||
| [*.html] | [*.html] | ||||||
| indent_size = 2 | indent_size = 2 | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							| @ -15,8 +15,8 @@ jobs: | |||||||
|       matrix: |       matrix: | ||||||
|         version: |         version: | ||||||
|           - docs |           - docs | ||||||
|           - version-2025-4 |  | ||||||
|           - version-2025-2 |           - version-2025-2 | ||||||
|  |           - version-2024-12 | ||||||
|     steps: |     steps: | ||||||
|       - uses: actions/checkout@v4 |       - uses: actions/checkout@v4 | ||||||
|       - run: | |       - run: | | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -202,7 +202,7 @@ jobs: | |||||||
|         uses: actions/cache@v4 |         uses: actions/cache@v4 | ||||||
|         with: |         with: | ||||||
|           path: web/dist |           path: web/dist | ||||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b |           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||||
|       - name: prepare web ui |       - name: prepare web ui | ||||||
|         if: steps.cache-web.outputs.cache-hit != 'true' |         if: steps.cache-web.outputs.cache-hit != 'true' | ||||||
|         working-directory: web |         working-directory: web | ||||||
|  | |||||||
							
								
								
									
										22
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										22
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							| @ -41,27 +41,6 @@ jobs: | |||||||
|       - name: test |       - name: test | ||||||
|         working-directory: website/ |         working-directory: website/ | ||||||
|         run: npm test |         run: npm test | ||||||
|   build: |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     name: ${{ matrix.job }} |  | ||||||
|     strategy: |  | ||||||
|       fail-fast: false |  | ||||||
|       matrix: |  | ||||||
|         job: |  | ||||||
|           - build |  | ||||||
|           - build:integrations |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|       - uses: actions/setup-node@v4 |  | ||||||
|         with: |  | ||||||
|           node-version-file: website/package.json |  | ||||||
|           cache: "npm" |  | ||||||
|           cache-dependency-path: website/package-lock.json |  | ||||||
|       - working-directory: website/ |  | ||||||
|         run: npm ci |  | ||||||
|       - name: build |  | ||||||
|         working-directory: website/ |  | ||||||
|         run: npm run ${{ matrix.job }} |  | ||||||
|   build-container: |   build-container: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     permissions: |     permissions: | ||||||
| @ -115,7 +94,6 @@ jobs: | |||||||
|     needs: |     needs: | ||||||
|       - lint |       - lint | ||||||
|       - test |       - test | ||||||
|       - build |  | ||||||
|       - build-container |       - build-container | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							| @ -2,7 +2,7 @@ name: "CodeQL" | |||||||
|  |  | ||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     branches: [main, next, version*] |     branches: [main, "*", next, version*] | ||||||
|   pull_request: |   pull_request: | ||||||
|     branches: [main] |     branches: [main] | ||||||
|   schedule: |   schedule: | ||||||
|  | |||||||
							
								
								
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							| @ -6,15 +6,13 @@ | |||||||
|         "!Context scalar", |         "!Context scalar", | ||||||
|         "!Enumerate sequence", |         "!Enumerate sequence", | ||||||
|         "!Env scalar", |         "!Env scalar", | ||||||
|         "!Env sequence", |  | ||||||
|         "!Find sequence", |         "!Find sequence", | ||||||
|         "!Format sequence", |         "!Format sequence", | ||||||
|         "!If sequence", |         "!If sequence", | ||||||
|         "!Index scalar", |         "!Index scalar", | ||||||
|         "!KeyOf scalar", |         "!KeyOf scalar", | ||||||
|         "!Value scalar", |         "!Value scalar", | ||||||
|         "!AtIndex scalar", |         "!AtIndex scalar" | ||||||
|         "!ParseJSON scalar" |  | ||||||
|     ], |     ], | ||||||
|     "typescript.preferences.importModuleSpecifier": "non-relative", |     "typescript.preferences.importModuleSpecifier": "non-relative", | ||||||
|     "typescript.preferences.importModuleSpecifierEnding": "index", |     "typescript.preferences.importModuleSpecifierEnding": "index", | ||||||
|  | |||||||
| @ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | |||||||
|     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" |     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||||
|  |  | ||||||
| # Stage 4: Download uv | # Stage 4: Download uv | ||||||
| FROM ghcr.io/astral-sh/uv:0.7.15 AS uv | FROM ghcr.io/astral-sh/uv:0.7.13 AS uv | ||||||
| # Stage 5: Base python image | # Stage 5: Base python image | ||||||
| FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base | FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base | ||||||
|  |  | ||||||
| ENV VENV_PATH="/ak-root/.venv" \ | ENV VENV_PATH="/ak-root/.venv" \ | ||||||
|     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ |     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ | ||||||
|  | |||||||
							
								
								
									
										12
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								Makefile
									
									
									
									
									
								
							| @ -86,10 +86,6 @@ dev-create-db: | |||||||
|  |  | ||||||
| dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | ||||||
|  |  | ||||||
| update-test-mmdb:  ## Update test GeoIP and ASN Databases |  | ||||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb |  | ||||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb |  | ||||||
|  |  | ||||||
| ######################### | ######################### | ||||||
| ## API Schema | ## API Schema | ||||||
| ######################### | ######################### | ||||||
| @ -98,7 +94,7 @@ gen-build:  ## Extract the schema from the database | |||||||
| 	AUTHENTIK_DEBUG=true \ | 	AUTHENTIK_DEBUG=true \ | ||||||
| 		AUTHENTIK_TENANTS__ENABLED=true \ | 		AUTHENTIK_TENANTS__ENABLED=true \ | ||||||
| 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | ||||||
| 		uv run ak make_blueprint_schema --file blueprints/schema.json | 		uv run ak make_blueprint_schema > blueprints/schema.json | ||||||
| 	AUTHENTIK_DEBUG=true \ | 	AUTHENTIK_DEBUG=true \ | ||||||
| 		AUTHENTIK_TENANTS__ENABLED=true \ | 		AUTHENTIK_TENANTS__ENABLED=true \ | ||||||
| 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | ||||||
| @ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | |||||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||||
| 		--git-repo-id authentik \ | 		--git-repo-id authentik \ | ||||||
| 		--git-user-id goauthentik | 		--git-user-id goauthentik | ||||||
|  | 	mkdir -p web/node_modules/@goauthentik/api | ||||||
| 	cd ${PWD}/${GEN_API_TS} && npm link | 	cd ${PWD}/${GEN_API_TS} && npm i | ||||||
| 	cd ${PWD}/web && npm link @goauthentik/api | 	\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||||
|  |  | ||||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python | gen-client-py: gen-clean-py ## Build and install the authentik API for Python | ||||||
| 	docker run \ | 	docker run \ | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2025.6.2" | __version__ = "2025.6.1" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -72,33 +72,20 @@ class Command(BaseCommand): | |||||||
|                     "additionalProperties": True, |                     "additionalProperties": True, | ||||||
|                 }, |                 }, | ||||||
|                 "entries": { |                 "entries": { | ||||||
|                     "anyOf": [ |                     "type": "array", | ||||||
|                         { |                     "items": { | ||||||
|                             "type": "array", |                         "oneOf": [], | ||||||
|                             "items": {"$ref": "#/$defs/blueprint_entry"}, |                     }, | ||||||
|                         }, |  | ||||||
|                         { |  | ||||||
|                             "type": "object", |  | ||||||
|                             "additionalProperties": { |  | ||||||
|                                 "type": "array", |  | ||||||
|                                 "items": {"$ref": "#/$defs/blueprint_entry"}, |  | ||||||
|                             }, |  | ||||||
|                         }, |  | ||||||
|                     ], |  | ||||||
|                 }, |                 }, | ||||||
|             }, |             }, | ||||||
|             "$defs": {"blueprint_entry": {"oneOf": []}}, |             "$defs": {}, | ||||||
|         } |         } | ||||||
|  |  | ||||||
|     def add_arguments(self, parser): |  | ||||||
|         parser.add_argument("--file", type=str) |  | ||||||
|  |  | ||||||
|     @no_translations |     @no_translations | ||||||
|     def handle(self, *args, file: str, **options): |     def handle(self, *args, **options): | ||||||
|         """Generate JSON Schema for blueprints""" |         """Generate JSON Schema for blueprints""" | ||||||
|         self.build() |         self.build() | ||||||
|         with open(file, "w") as _schema: |         self.stdout.write(dumps(self.schema, indent=4, default=Command.json_default)) | ||||||
|             _schema.write(dumps(self.schema, indent=4, default=Command.json_default)) |  | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def json_default(value: Any) -> Any: |     def json_default(value: Any) -> Any: | ||||||
| @ -125,7 +112,7 @@ class Command(BaseCommand): | |||||||
|                 } |                 } | ||||||
|             ) |             ) | ||||||
|             model_path = f"{model._meta.app_label}.{model._meta.model_name}" |             model_path = f"{model._meta.app_label}.{model._meta.model_name}" | ||||||
|             self.schema["$defs"]["blueprint_entry"]["oneOf"].append( |             self.schema["properties"]["entries"]["items"]["oneOf"].append( | ||||||
|                 self.template_entry(model_path, model, serializer) |                 self.template_entry(model_path, model, serializer) | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,11 +1,10 @@ | |||||||
| version: 1 | version: 1 | ||||||
| entries: | entries: | ||||||
|   foo: |     - identifiers: | ||||||
|       - identifiers: |           name: "%(id)s" | ||||||
|             name: "%(id)s" |           slug: "%(id)s" | ||||||
|             slug: "%(id)s" |       model: authentik_flows.flow | ||||||
|         model: authentik_flows.flow |       state: present | ||||||
|         state: present |       attrs: | ||||||
|         attrs: |           designation: stage_configuration | ||||||
|             designation: stage_configuration |           title: foo | ||||||
|             title: foo |  | ||||||
|  | |||||||
| @ -37,7 +37,6 @@ entries: | |||||||
|     - attrs: |     - attrs: | ||||||
|           attributes: |           attributes: | ||||||
|               env_null: !Env [bar-baz, null] |               env_null: !Env [bar-baz, null] | ||||||
|               json_parse: !ParseJSON '{"foo": "bar"}' |  | ||||||
|               policy_pk1: |               policy_pk1: | ||||||
|                   !Format [ |                   !Format [ | ||||||
|                       "%s-%s", |                       "%s-%s", | ||||||
|  | |||||||
| @ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable: | |||||||
|  |  | ||||||
|  |  | ||||||
| for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||||
|     if "local" in str(blueprint_file) or "testing" in str(blueprint_file): |     if "local" in str(blueprint_file): | ||||||
|         continue |         continue | ||||||
|     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) |     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||||
|  | |||||||
| @ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase): | |||||||
|                     }, |                     }, | ||||||
|                     "nested_context": "context-nested-value", |                     "nested_context": "context-nested-value", | ||||||
|                     "env_null": None, |                     "env_null": None, | ||||||
|                     "json_parse": {"foo": "bar"}, |  | ||||||
|                     "at_index_sequence": "foo", |                     "at_index_sequence": "foo", | ||||||
|                     "at_index_sequence_default": "non existent", |                     "at_index_sequence_default": "non existent", | ||||||
|                     "at_index_mapping": 2, |                     "at_index_mapping": 2, | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from copy import copy | |||||||
| from dataclasses import asdict, dataclass, field, is_dataclass | from dataclasses import asdict, dataclass, field, is_dataclass | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from functools import reduce | from functools import reduce | ||||||
| from json import JSONDecodeError, loads |  | ||||||
| from operator import ixor | from operator import ixor | ||||||
| from os import getenv | from os import getenv | ||||||
| from typing import Any, Literal, Union | from typing import Any, Literal, Union | ||||||
| @ -192,18 +191,11 @@ class Blueprint: | |||||||
|     """Dataclass used for a full export""" |     """Dataclass used for a full export""" | ||||||
|  |  | ||||||
|     version: int = field(default=1) |     version: int = field(default=1) | ||||||
|     entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = field(default_factory=list) |     entries: list[BlueprintEntry] = field(default_factory=list) | ||||||
|     context: dict = field(default_factory=dict) |     context: dict = field(default_factory=dict) | ||||||
|  |  | ||||||
|     metadata: BlueprintMetadata | None = field(default=None) |     metadata: BlueprintMetadata | None = field(default=None) | ||||||
|  |  | ||||||
|     def iter_entries(self) -> Iterable[BlueprintEntry]: |  | ||||||
|         if isinstance(self.entries, dict): |  | ||||||
|             for _section, entries in self.entries.items(): |  | ||||||
|                 yield from entries |  | ||||||
|         else: |  | ||||||
|             yield from self.entries |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class YAMLTag: | class YAMLTag: | ||||||
|     """Base class for all YAML Tags""" |     """Base class for all YAML Tags""" | ||||||
| @ -234,7 +226,7 @@ class KeyOf(YAMLTag): | |||||||
|         self.id_from = node.value |         self.id_from = node.value | ||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||||
|         for _entry in blueprint.iter_entries(): |         for _entry in blueprint.entries: | ||||||
|             if _entry.id == self.id_from and _entry._state.instance: |             if _entry.id == self.id_from and _entry._state.instance: | ||||||
|                 # Special handling for PolicyBindingModels, as they'll have a different PK |                 # Special handling for PolicyBindingModels, as they'll have a different PK | ||||||
|                 # which is used when creating policy bindings |                 # which is used when creating policy bindings | ||||||
| @ -292,22 +284,6 @@ class Context(YAMLTag): | |||||||
|         return value |         return value | ||||||
|  |  | ||||||
|  |  | ||||||
| class ParseJSON(YAMLTag): |  | ||||||
|     """Parse JSON from context/env/etc value""" |  | ||||||
|  |  | ||||||
|     raw: str |  | ||||||
|  |  | ||||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: |  | ||||||
|         super().__init__() |  | ||||||
|         self.raw = node.value |  | ||||||
|  |  | ||||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: |  | ||||||
|         try: |  | ||||||
|             return loads(self.raw) |  | ||||||
|         except JSONDecodeError as exc: |  | ||||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Format(YAMLTag): | class Format(YAMLTag): | ||||||
|     """Format a string""" |     """Format a string""" | ||||||
|  |  | ||||||
| @ -683,7 +659,6 @@ class BlueprintLoader(SafeLoader): | |||||||
|         self.add_constructor("!Value", Value) |         self.add_constructor("!Value", Value) | ||||||
|         self.add_constructor("!Index", Index) |         self.add_constructor("!Index", Index) | ||||||
|         self.add_constructor("!AtIndex", AtIndex) |         self.add_constructor("!AtIndex", AtIndex) | ||||||
|         self.add_constructor("!ParseJSON", ParseJSON) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class EntryInvalidError(SentryIgnoredException): | class EntryInvalidError(SentryIgnoredException): | ||||||
|  | |||||||
| @ -384,7 +384,7 @@ class Importer: | |||||||
|     def _apply_models(self, raise_errors=False) -> bool: |     def _apply_models(self, raise_errors=False) -> bool: | ||||||
|         """Apply (create/update) models yaml""" |         """Apply (create/update) models yaml""" | ||||||
|         self.__pk_map = {} |         self.__pk_map = {} | ||||||
|         for entry in self._import.iter_entries(): |         for entry in self._import.entries: | ||||||
|             model_app_label, model_name = entry.get_model(self._import).split(".") |             model_app_label, model_name = entry.get_model(self._import).split(".") | ||||||
|             try: |             try: | ||||||
|                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) |                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| """Authenticator Devices API Views""" | """Authenticator Devices API Views""" | ||||||
|  |  | ||||||
| from drf_spectacular.utils import extend_schema | from django.utils.translation import gettext_lazy as _ | ||||||
|  | from drf_spectacular.types import OpenApiTypes | ||||||
|  | from drf_spectacular.utils import OpenApiParameter, extend_schema | ||||||
| from guardian.shortcuts import get_objects_for_user | from guardian.shortcuts import get_objects_for_user | ||||||
| from rest_framework.fields import ( | from rest_framework.fields import ( | ||||||
|     BooleanField, |     BooleanField, | ||||||
| @ -13,7 +15,6 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from rest_framework.viewsets import ViewSet | from rest_framework.viewsets import ViewSet | ||||||
|  |  | ||||||
| from authentik.core.api.users import ParamUserSerializer |  | ||||||
| from authentik.core.api.utils import MetaNameSerializer | from authentik.core.api.utils import MetaNameSerializer | ||||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | ||||||
| from authentik.stages.authenticator import device_classes, devices_for_user | from authentik.stages.authenticator import device_classes, devices_for_user | ||||||
| @ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice | |||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceSerializer(MetaNameSerializer): | class DeviceSerializer(MetaNameSerializer): | ||||||
|     """Serializer for authenticator devices""" |     """Serializer for Duo authenticator devices""" | ||||||
|  |  | ||||||
|     pk = CharField() |     pk = CharField() | ||||||
|     name = CharField() |     name = CharField() | ||||||
| @ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer): | |||||||
|     last_updated = DateTimeField(read_only=True) |     last_updated = DateTimeField(read_only=True) | ||||||
|     last_used = DateTimeField(read_only=True, allow_null=True) |     last_used = DateTimeField(read_only=True, allow_null=True) | ||||||
|     extra_description = SerializerMethodField() |     extra_description = SerializerMethodField() | ||||||
|     external_id = SerializerMethodField() |  | ||||||
|  |  | ||||||
|     def get_type(self, instance: Device) -> str: |     def get_type(self, instance: Device) -> str: | ||||||
|         """Get type of device""" |         """Get type of device""" | ||||||
|         return instance._meta.label |         return instance._meta.label | ||||||
|  |  | ||||||
|     def get_extra_description(self, instance: Device) -> str | None: |     def get_extra_description(self, instance: Device) -> str: | ||||||
|         """Get extra description""" |         """Get extra description""" | ||||||
|         if isinstance(instance, WebAuthnDevice): |         if isinstance(instance, WebAuthnDevice): | ||||||
|             return instance.device_type.description if instance.device_type else None |             return ( | ||||||
|  |                 instance.device_type.description | ||||||
|  |                 if instance.device_type | ||||||
|  |                 else _("Extra description not available") | ||||||
|  |             ) | ||||||
|         if isinstance(instance, EndpointDevice): |         if isinstance(instance, EndpointDevice): | ||||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") |             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||||
|         return None |         return "" | ||||||
|  |  | ||||||
|     def get_external_id(self, instance: Device) -> str | None: |  | ||||||
|         """Get external Device ID""" |  | ||||||
|         if isinstance(instance, WebAuthnDevice): |  | ||||||
|             return instance.device_type.aaguid if instance.device_type else None |  | ||||||
|         if isinstance(instance, EndpointDevice): |  | ||||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceViewSet(ViewSet): | class DeviceViewSet(ViewSet): | ||||||
| @ -61,6 +57,7 @@ class DeviceViewSet(ViewSet): | |||||||
|     serializer_class = DeviceSerializer |     serializer_class = DeviceSerializer | ||||||
|     permission_classes = [IsAuthenticated] |     permission_classes = [IsAuthenticated] | ||||||
|  |  | ||||||
|  |     @extend_schema(responses={200: DeviceSerializer(many=True)}) | ||||||
|     def list(self, request: Request) -> Response: |     def list(self, request: Request) -> Response: | ||||||
|         """Get all devices for current user""" |         """Get all devices for current user""" | ||||||
|         devices = devices_for_user(request.user) |         devices = devices_for_user(request.user) | ||||||
| @ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet): | |||||||
|             yield from device_set |             yield from device_set | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         parameters=[ParamUserSerializer], |         parameters=[ | ||||||
|  |             OpenApiParameter( | ||||||
|  |                 name="user", | ||||||
|  |                 location=OpenApiParameter.QUERY, | ||||||
|  |                 type=OpenApiTypes.INT, | ||||||
|  |             ) | ||||||
|  |         ], | ||||||
|         responses={200: DeviceSerializer(many=True)}, |         responses={200: DeviceSerializer(many=True)}, | ||||||
|     ) |     ) | ||||||
|     def list(self, request: Request) -> Response: |     def list(self, request: Request) -> Response: | ||||||
|         """Get all devices for current user""" |         """Get all devices for current user""" | ||||||
|         args = ParamUserSerializer(data=request.query_params) |         kwargs = {} | ||||||
|         args.is_valid(raise_exception=True) |         if "user" in request.query_params: | ||||||
|         return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data) |             kwargs = {"user": request.query_params["user"]} | ||||||
|  |         return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data) | ||||||
|  | |||||||
| @ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| class ParamUserSerializer(PassiveSerializer): |  | ||||||
|     """Partial serializer for query parameters to select a user""" |  | ||||||
|  |  | ||||||
|     user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UserGroupSerializer(ModelSerializer): | class UserGroupSerializer(ModelSerializer): | ||||||
|     """Simplified Group Serializer for user's groups""" |     """Simplified Group Serializer for user's groups""" | ||||||
|  |  | ||||||
| @ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     queryset = User.objects.none() |     queryset = User.objects.none() | ||||||
|     ordering = ["username"] |     ordering = ["username"] | ||||||
|     serializer_class = UserSerializer |     serializer_class = UserSerializer | ||||||
|     filterset_class = UsersFilter |  | ||||||
|     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] |     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] | ||||||
|  |     filterset_class = UsersFilter | ||||||
|     def get_ql_fields(self): |  | ||||||
|         from djangoql.schema import BoolField, StrField |  | ||||||
|  |  | ||||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             StrField(User, "username"), |  | ||||||
|             StrField(User, "name"), |  | ||||||
|             StrField(User, "email"), |  | ||||||
|             StrField(User, "path"), |  | ||||||
|             BoolField(User, "is_active", nullable=True), |  | ||||||
|             ChoiceSearchField(User, "type"), |  | ||||||
|             JSONSearchField(User, "attributes", suggest_nested=False), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     def get_queryset(self): |     def get_queryset(self): | ||||||
|         base_qs = User.objects.all().exclude_anonymous() |         base_qs = User.objects.all().exclude_anonymous() | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ | |||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
| from django.db import models |  | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from drf_spectacular.extensions import OpenApiSerializerFieldExtension | from drf_spectacular.extensions import OpenApiSerializerFieldExtension | ||||||
| from drf_spectacular.plumbing import build_basic_type | from drf_spectacular.plumbing import build_basic_type | ||||||
| @ -31,27 +30,7 @@ def is_dict(value: Any): | |||||||
|     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") |     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONDictField(JSONField): |  | ||||||
|     """JSON Field which only allows dictionaries""" |  | ||||||
|  |  | ||||||
|     default_validators = [is_dict] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONExtension(OpenApiSerializerFieldExtension): |  | ||||||
|     """Generate API Schema for JSON fields as""" |  | ||||||
|  |  | ||||||
|     target_class = "authentik.core.api.utils.JSONDictField" |  | ||||||
|  |  | ||||||
|     def map_serializer_field(self, auto_schema, direction): |  | ||||||
|         return build_basic_type(OpenApiTypes.OBJECT) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ModelSerializer(BaseModelSerializer): | class ModelSerializer(BaseModelSerializer): | ||||||
|  |  | ||||||
|     # By default, JSON fields we have are used to store dictionaries |  | ||||||
|     serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy() |  | ||||||
|     serializer_field_mapping[models.JSONField] = JSONDictField |  | ||||||
|  |  | ||||||
|     def create(self, validated_data): |     def create(self, validated_data): | ||||||
|         instance = super().create(validated_data) |         instance = super().create(validated_data) | ||||||
|  |  | ||||||
| @ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer): | |||||||
|         return instance |         return instance | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class JSONDictField(JSONField): | ||||||
|  |     """JSON Field which only allows dictionaries""" | ||||||
|  |  | ||||||
|  |     default_validators = [is_dict] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class JSONExtension(OpenApiSerializerFieldExtension): | ||||||
|  |     """Generate API Schema for JSON fields as""" | ||||||
|  |  | ||||||
|  |     target_class = "authentik.core.api.utils.JSONDictField" | ||||||
|  |  | ||||||
|  |     def map_serializer_field(self, auto_schema, direction): | ||||||
|  |         return build_basic_type(OpenApiTypes.OBJECT) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PassiveSerializer(Serializer): | class PassiveSerializer(Serializer): | ||||||
|     """Base serializer class which doesn't implement create/update methods""" |     """Base serializer class which doesn't implement create/update methods""" | ||||||
|  |  | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ class Command(TenantCommand): | |||||||
|         parser.add_argument("usernames", nargs="*", type=str) |         parser.add_argument("usernames", nargs="*", type=str) | ||||||
|  |  | ||||||
|     def handle_per_tenant(self, **options): |     def handle_per_tenant(self, **options): | ||||||
|  |         print(options) | ||||||
|         new_type = UserTypes(options["type"]) |         new_type = UserTypes(options["type"]) | ||||||
|         qs = ( |         qs = ( | ||||||
|             User.objects.exclude_anonymous() |             User.objects.exclude_anonymous() | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from django.http import HttpRequest | |||||||
| from django.utils.functional import SimpleLazyObject, cached_property | from django.utils.functional import SimpleLazyObject, cached_property | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from django_cte import CTE, with_cte | from django_cte import CTEQuerySet, With | ||||||
| from guardian.conf import settings | from guardian.conf import settings | ||||||
| from guardian.mixins import GuardianUserMixin | from guardian.mixins import GuardianUserMixin | ||||||
| from model_utils.managers import InheritanceManager | from model_utils.managers import InheritanceManager | ||||||
| @ -136,7 +136,7 @@ class AttributesMixin(models.Model): | |||||||
|         return instance, False |         return instance, False | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupQuerySet(QuerySet): | class GroupQuerySet(CTEQuerySet): | ||||||
|     def with_children_recursive(self): |     def with_children_recursive(self): | ||||||
|         """Recursively get all groups that have the current queryset as parents |         """Recursively get all groups that have the current queryset as parents | ||||||
|         or are indirectly related.""" |         or are indirectly related.""" | ||||||
| @ -165,9 +165,9 @@ class GroupQuerySet(QuerySet): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         # Build the recursive query, see above |         # Build the recursive query, see above | ||||||
|         cte = CTE.recursive(make_cte) |         cte = With.recursive(make_cte) | ||||||
|         # Return the result, as a usable queryset for Group. |         # Return the result, as a usable queryset for Group. | ||||||
|         return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid)) |         return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Group(SerializerModel, AttributesMixin): | class Group(SerializerModel, AttributesMixin): | ||||||
|  | |||||||
| @ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase): | |||||||
|             [ |             [ | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter1",  # nosec |                     old_password="hunter1",  # nosec B106 | ||||||
|                     created_at=_now - timedelta(days=3), |                     created_at=_now - timedelta(days=3), | ||||||
|                 ), |                 ), | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter2",  # nosec |                     old_password="hunter2",  # nosec B106 | ||||||
|                     created_at=_now - timedelta(days=2), |                     created_at=_now - timedelta(days=2), | ||||||
|                 ), |                 ), | ||||||
|                 UserPasswordHistory( |                 UserPasswordHistory( | ||||||
|                     user=self.user, |                     user=self.user, | ||||||
|                     old_password="hunter3",  # nosec |                     old_password="hunter3",  # nosec B106 | ||||||
|                     created_at=_now, |                     created_at=_now, | ||||||
|                 ), |                 ), | ||||||
|             ] |             ] | ||||||
|  | |||||||
| @ -1,8 +1,10 @@ | |||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
|  |  | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.signals import post_delete, post_save, pre_delete | from django.db.models.signals import post_delete, post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http.request import HttpRequest | ||||||
| from guardian.shortcuts import assign_perm | from guardian.shortcuts import assign_perm | ||||||
|  |  | ||||||
| from authentik.core.models import ( | from authentik.core.models import ( | ||||||
| @ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: | |||||||
|             instance.save() |             instance.save() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Session revoked trigger (user logged out)""" | ||||||
|  |     if not request.session or not request.session.session_key or not user: | ||||||
|  |         return | ||||||
|  |     send_ssf_event( | ||||||
|  |         EventTypes.CAEP_SESSION_REVOKED, | ||||||
|  |         { | ||||||
|  |             "initiating_entity": "user", | ||||||
|  |         }, | ||||||
|  |         sub_id={ | ||||||
|  |             "format": "complex", | ||||||
|  |             "session": { | ||||||
|  |                 "format": "opaque", | ||||||
|  |                 "id": sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||||
|  |             }, | ||||||
|  |             "user": { | ||||||
|  |                 "format": "email", | ||||||
|  |                 "email": user.email, | ||||||
|  |             }, | ||||||
|  |         }, | ||||||
|  |         request=request, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | ||||||
|     """Session revoked trigger (users' session has been deleted) |     """Session revoked trigger (users' session has been deleted) | ||||||
|  | |||||||
| @ -1,12 +0,0 @@ | |||||||
| """Enterprise app config""" |  | ||||||
|  |  | ||||||
| from authentik.enterprise.apps import EnterpriseConfig |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthentikEnterpriseSearchConfig(EnterpriseConfig): |  | ||||||
|     """Enterprise app config""" |  | ||||||
|  |  | ||||||
|     name = "authentik.enterprise.search" |  | ||||||
|     label = "authentik_search" |  | ||||||
|     verbose_name = "authentik Enterprise.Search" |  | ||||||
|     default = True |  | ||||||
| @ -1,128 +0,0 @@ | |||||||
| """DjangoQL search""" |  | ||||||
|  |  | ||||||
| from collections import OrderedDict, defaultdict |  | ||||||
| from collections.abc import Generator |  | ||||||
|  |  | ||||||
| from django.db import connection |  | ||||||
| from django.db.models import Model, Q |  | ||||||
| from djangoql.compat import text_type |  | ||||||
| from djangoql.schema import StrField |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class JSONSearchField(StrField): |  | ||||||
|     """JSON field for DjangoQL""" |  | ||||||
|  |  | ||||||
|     model: Model |  | ||||||
|  |  | ||||||
|     def __init__(self, model=None, name=None, nullable=None, suggest_nested=True): |  | ||||||
|         # Set this in the constructor to not clobber the type variable |  | ||||||
|         self.type = "relation" |  | ||||||
|         self.suggest_nested = suggest_nested |  | ||||||
|         super().__init__(model, name, nullable) |  | ||||||
|  |  | ||||||
|     def get_lookup(self, path, operator, value): |  | ||||||
|         search = "__".join(path) |  | ||||||
|         op, invert = self.get_operator(operator) |  | ||||||
|         q = Q(**{f"{search}{op}": self.get_lookup_value(value)}) |  | ||||||
|         return ~q if invert else q |  | ||||||
|  |  | ||||||
|     def json_field_keys(self) -> Generator[tuple[str]]: |  | ||||||
|         with connection.cursor() as cursor: |  | ||||||
|             cursor.execute( |  | ||||||
|                 f""" |  | ||||||
|                 WITH RECURSIVE "{self.name}_keys" AS ( |  | ||||||
|                     SELECT |  | ||||||
|                         ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array, |  | ||||||
|                         "{self.name}" -> jsonb_object_keys("{self.name}") AS value |  | ||||||
|                     FROM {self.model._meta.db_table} |  | ||||||
|                     WHERE "{self.name}" IS NOT NULL |  | ||||||
|                         AND jsonb_typeof("{self.name}") = 'object' |  | ||||||
|  |  | ||||||
|                     UNION ALL |  | ||||||
|  |  | ||||||
|                     SELECT |  | ||||||
|                         ck.key_path_array || jsonb_object_keys(ck.value), |  | ||||||
|                         ck.value -> jsonb_object_keys(ck.value) AS value |  | ||||||
|                     FROM "{self.name}_keys" ck |  | ||||||
|                     WHERE jsonb_typeof(ck.value) = 'object' |  | ||||||
|                 ), |  | ||||||
|  |  | ||||||
|                 unique_paths AS ( |  | ||||||
|                     SELECT DISTINCT key_path_array |  | ||||||
|                     FROM "{self.name}_keys" |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|                 SELECT key_path_array FROM unique_paths; |  | ||||||
|             """  # nosec |  | ||||||
|             ) |  | ||||||
|             return (x[0] for x in cursor.fetchall()) |  | ||||||
|  |  | ||||||
|     def get_nested_options(self) -> OrderedDict: |  | ||||||
|         """Get keys of all nested objects to show autocomplete""" |  | ||||||
|         if not self.suggest_nested: |  | ||||||
|             return OrderedDict() |  | ||||||
|         base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" |  | ||||||
|  |  | ||||||
|         def recursive_function(parts: list[str], parent_parts: list[str] | None = None): |  | ||||||
|             if not parent_parts: |  | ||||||
|                 parent_parts = [] |  | ||||||
|             path = parts.pop(0) |  | ||||||
|             parent_parts.append(path) |  | ||||||
|             relation_key = "_".join(parent_parts) |  | ||||||
|             if len(parts) > 1: |  | ||||||
|                 out_dict = { |  | ||||||
|                     relation_key: { |  | ||||||
|                         parts[0]: { |  | ||||||
|                             "type": "relation", |  | ||||||
|                             "relation": f"{relation_key}_{parts[0]}", |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|                 child_paths = recursive_function(parts.copy(), parent_parts.copy()) |  | ||||||
|                 child_paths.update(out_dict) |  | ||||||
|                 return child_paths |  | ||||||
|             else: |  | ||||||
|                 return {relation_key: {parts[0]: {}}} |  | ||||||
|  |  | ||||||
|         relation_structure = defaultdict(dict) |  | ||||||
|  |  | ||||||
|         for relations in self.json_field_keys(): |  | ||||||
|             result = recursive_function([base_model_name] + relations) |  | ||||||
|             for relation_key, value in result.items(): |  | ||||||
|                 for sub_relation_key, sub_value in value.items(): |  | ||||||
|                     if not relation_structure[relation_key].get(sub_relation_key, None): |  | ||||||
|                         relation_structure[relation_key][sub_relation_key] = sub_value |  | ||||||
|                     else: |  | ||||||
|                         relation_structure[relation_key][sub_relation_key].update(sub_value) |  | ||||||
|  |  | ||||||
|         final_dict = defaultdict(dict) |  | ||||||
|  |  | ||||||
|         for key, value in relation_structure.items(): |  | ||||||
|             for sub_key, sub_value in value.items(): |  | ||||||
|                 if not sub_value: |  | ||||||
|                     final_dict[key][sub_key] = { |  | ||||||
|                         "type": "str", |  | ||||||
|                         "nullable": True, |  | ||||||
|                     } |  | ||||||
|                 else: |  | ||||||
|                     final_dict[key][sub_key] = sub_value |  | ||||||
|         return OrderedDict(final_dict) |  | ||||||
|  |  | ||||||
|     def relation(self) -> str: |  | ||||||
|         return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ChoiceSearchField(StrField): |  | ||||||
|     def __init__(self, model=None, name=None, nullable=None): |  | ||||||
|         super().__init__(model, name, nullable, suggest_options=True) |  | ||||||
|  |  | ||||||
|     def get_options(self, search): |  | ||||||
|         result = [] |  | ||||||
|         choices = self._field_choices() |  | ||||||
|         if choices: |  | ||||||
|             search = search.lower() |  | ||||||
|             for c in choices: |  | ||||||
|                 choice = text_type(c[0]) |  | ||||||
|                 if search in choice.lower(): |  | ||||||
|                     result.append(choice) |  | ||||||
|         return result |  | ||||||
| @ -1,53 +0,0 @@ | |||||||
| from rest_framework.response import Response |  | ||||||
|  |  | ||||||
| from authentik.api.pagination import Pagination |  | ||||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AutocompletePagination(Pagination): |  | ||||||
|  |  | ||||||
|     def paginate_queryset(self, queryset, request, view=None): |  | ||||||
|         self.view = view |  | ||||||
|         return super().paginate_queryset(queryset, request, view) |  | ||||||
|  |  | ||||||
|     def get_autocomplete(self): |  | ||||||
|         schema = QLSearch().get_schema(self.request, self.view) |  | ||||||
|         introspections = {} |  | ||||||
|         if hasattr(self.view, "get_ql_fields"): |  | ||||||
|             from authentik.enterprise.search.schema import AKQLSchemaSerializer |  | ||||||
|  |  | ||||||
|             introspections = AKQLSchemaSerializer().serialize( |  | ||||||
|                 schema(self.page.paginator.object_list.model) |  | ||||||
|             ) |  | ||||||
|         return introspections |  | ||||||
|  |  | ||||||
|     def get_paginated_response(self, data): |  | ||||||
|         previous_page_number = 0 |  | ||||||
|         if self.page.has_previous(): |  | ||||||
|             previous_page_number = self.page.previous_page_number() |  | ||||||
|         next_page_number = 0 |  | ||||||
|         if self.page.has_next(): |  | ||||||
|             next_page_number = self.page.next_page_number() |  | ||||||
|         return Response( |  | ||||||
|             { |  | ||||||
|                 "pagination": { |  | ||||||
|                     "next": next_page_number, |  | ||||||
|                     "previous": previous_page_number, |  | ||||||
|                     "count": self.page.paginator.count, |  | ||||||
|                     "current": self.page.number, |  | ||||||
|                     "total_pages": self.page.paginator.num_pages, |  | ||||||
|                     "start_index": self.page.start_index(), |  | ||||||
|                     "end_index": self.page.end_index(), |  | ||||||
|                 }, |  | ||||||
|                 "results": data, |  | ||||||
|                 "autocomplete": self.get_autocomplete(), |  | ||||||
|             } |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def get_paginated_response_schema(self, schema): |  | ||||||
|         final_schema = super().get_paginated_response_schema(schema) |  | ||||||
|         final_schema["properties"]["autocomplete"] = { |  | ||||||
|             "$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}" |  | ||||||
|         } |  | ||||||
|         final_schema["required"].append("autocomplete") |  | ||||||
|         return final_schema |  | ||||||
| @ -1,78 +0,0 @@ | |||||||
| """DjangoQL search""" |  | ||||||
|  |  | ||||||
| from django.apps import apps |  | ||||||
| from django.db.models import QuerySet |  | ||||||
| from djangoql.ast import Name |  | ||||||
| from djangoql.exceptions import DjangoQLError |  | ||||||
| from djangoql.queryset import apply_search |  | ||||||
| from djangoql.schema import DjangoQLSchema |  | ||||||
| from rest_framework.filters import SearchFilter |  | ||||||
| from rest_framework.request import Request |  | ||||||
| from structlog.stdlib import get_logger |  | ||||||
|  |  | ||||||
| from authentik.enterprise.search.fields import JSONSearchField |  | ||||||
|  |  | ||||||
| LOGGER = get_logger() |  | ||||||
| AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete" |  | ||||||
| AUTOCOMPLETE_SCHEMA = { |  | ||||||
|     "type": "object", |  | ||||||
|     "additionalProperties": {}, |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseSchema(DjangoQLSchema): |  | ||||||
|     """Base Schema which deals with JSON Fields""" |  | ||||||
|  |  | ||||||
|     def resolve_name(self, name: Name): |  | ||||||
|         model = self.model_label(self.current_model) |  | ||||||
|         root_field = name.parts[0] |  | ||||||
|         field = self.models[model].get(root_field) |  | ||||||
|         # If the query goes into a JSON field, return the root |  | ||||||
|         # field as the JSON field will do the rest |  | ||||||
|         if isinstance(field, JSONSearchField): |  | ||||||
|             # This is a workaround; build_filter will remove the right-most |  | ||||||
|             # entry in the path as that is intended to be the same as the field |  | ||||||
|             # however for JSON that is not the case |  | ||||||
|             if name.parts[-1] != root_field: |  | ||||||
|                 name.parts.append(root_field) |  | ||||||
|             return field |  | ||||||
|         return super().resolve_name(name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class QLSearch(SearchFilter): |  | ||||||
|     """rest_framework search filter which uses DjangoQL""" |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def enabled(self): |  | ||||||
|         return apps.get_app_config("authentik_enterprise").enabled() |  | ||||||
|  |  | ||||||
|     def get_search_terms(self, request) -> str: |  | ||||||
|         """ |  | ||||||
|         Search terms are set by a ?search=... query parameter, |  | ||||||
|         and may be comma and/or whitespace delimited. |  | ||||||
|         """ |  | ||||||
|         params = request.query_params.get(self.search_param, "") |  | ||||||
|         params = params.replace("\x00", "")  # strip null characters |  | ||||||
|         return params |  | ||||||
|  |  | ||||||
|     def get_schema(self, request: Request, view) -> BaseSchema: |  | ||||||
|         ql_fields = [] |  | ||||||
|         if hasattr(view, "get_ql_fields"): |  | ||||||
|             ql_fields = view.get_ql_fields() |  | ||||||
|  |  | ||||||
|         class InlineSchema(BaseSchema): |  | ||||||
|             def get_fields(self, model): |  | ||||||
|                 return ql_fields or [] |  | ||||||
|  |  | ||||||
|         return InlineSchema |  | ||||||
|  |  | ||||||
|     def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet: |  | ||||||
|         search_query = self.get_search_terms(request) |  | ||||||
|         schema = self.get_schema(request, view) |  | ||||||
|         if len(search_query) == 0 or not self.enabled: |  | ||||||
|             return super().filter_queryset(request, queryset, view) |  | ||||||
|         try: |  | ||||||
|             return apply_search(queryset, search_query, schema=schema) |  | ||||||
|         except DjangoQLError as exc: |  | ||||||
|             LOGGER.debug("Failed to parse search expression", exc=exc) |  | ||||||
|             return super().filter_queryset(request, queryset, view) |  | ||||||
| @ -1,29 +0,0 @@ | |||||||
| from djangoql.serializers import DjangoQLSchemaSerializer |  | ||||||
| from drf_spectacular.generators import SchemaGenerator |  | ||||||
|  |  | ||||||
| from authentik.api.schema import create_component |  | ||||||
| from authentik.enterprise.search.fields import JSONSearchField |  | ||||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AKQLSchemaSerializer(DjangoQLSchemaSerializer): |  | ||||||
|     def serialize(self, schema): |  | ||||||
|         serialization = super().serialize(schema) |  | ||||||
|         for _, fields in schema.models.items(): |  | ||||||
|             for _, field in fields.items(): |  | ||||||
|                 if not isinstance(field, JSONSearchField): |  | ||||||
|                     continue |  | ||||||
|                 serialization["models"].update(field.get_nested_options()) |  | ||||||
|         return serialization |  | ||||||
|  |  | ||||||
|     def serialize_field(self, field): |  | ||||||
|         result = super().serialize_field(field) |  | ||||||
|         if isinstance(field, JSONSearchField): |  | ||||||
|             result["relation"] = field.relation() |  | ||||||
|         return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs): |  | ||||||
|     create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA) |  | ||||||
|  |  | ||||||
|     return result |  | ||||||
| @ -1,17 +0,0 @@ | |||||||
| SPECTACULAR_SETTINGS = { |  | ||||||
|     "POSTPROCESSING_HOOKS": [ |  | ||||||
|         "authentik.api.schema.postprocess_schema_responses", |  | ||||||
|         "authentik.enterprise.search.schema.postprocess_schema_search_autocomplete", |  | ||||||
|         "drf_spectacular.hooks.postprocess_schema_enums", |  | ||||||
|     ], |  | ||||||
| } |  | ||||||
|  |  | ||||||
| REST_FRAMEWORK = { |  | ||||||
|     "DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination", |  | ||||||
|     "DEFAULT_FILTER_BACKENDS": [ |  | ||||||
|         "authentik.enterprise.search.ql.QLSearch", |  | ||||||
|         "authentik.rbac.filters.ObjectFilter", |  | ||||||
|         "django_filters.rest_framework.DjangoFilterBackend", |  | ||||||
|         "rest_framework.filters.OrderingFilter", |  | ||||||
|     ], |  | ||||||
| } |  | ||||||
| @ -1,78 +0,0 @@ | |||||||
| from json import loads |  | ||||||
| from unittest.mock import PropertyMock, patch |  | ||||||
| from urllib.parse import urlencode |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.tests.utils import create_test_admin_user |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @patch( |  | ||||||
|     "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", |  | ||||||
|     PropertyMock(return_value=True), |  | ||||||
| ) |  | ||||||
| class QLTest(APITestCase): |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         self.user = create_test_admin_user() |  | ||||||
|         # ensure we have more than 1 user |  | ||||||
|         create_test_admin_user() |  | ||||||
|  |  | ||||||
|     def test_search(self): |  | ||||||
|         """Test simple search query""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         query = f'username = "{self.user.username}"' |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": query})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
|  |  | ||||||
|     def test_no_search(self): |  | ||||||
|         """Ensure works with no search query""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertNotEqual(content["pagination"]["count"], 1) |  | ||||||
|  |  | ||||||
|     def test_search_no_ql(self): |  | ||||||
|         """Test simple search query (no QL)""" |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": self.user.username})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertGreaterEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
|  |  | ||||||
|     def test_search_json(self): |  | ||||||
|         """Test search query with a JSON attribute""" |  | ||||||
|         self.user.attributes = {"foo": {"bar": "baz"}} |  | ||||||
|         self.user.save() |  | ||||||
|         self.client.force_login(self.user) |  | ||||||
|         query = 'attributes.foo.bar = "baz"' |  | ||||||
|         res = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_api:user-list", |  | ||||||
|             ) |  | ||||||
|             + f"?{urlencode({"search": query})}" |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(res.status_code, 200) |  | ||||||
|         content = loads(res.content) |  | ||||||
|         self.assertEqual(content["pagination"]["count"], 1) |  | ||||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) |  | ||||||
| @ -18,7 +18,6 @@ TENANT_APPS = [ | |||||||
|     "authentik.enterprise.providers.google_workspace", |     "authentik.enterprise.providers.google_workspace", | ||||||
|     "authentik.enterprise.providers.microsoft_entra", |     "authentik.enterprise.providers.microsoft_entra", | ||||||
|     "authentik.enterprise.providers.ssf", |     "authentik.enterprise.providers.ssf", | ||||||
|     "authentik.enterprise.search", |  | ||||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", |     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||||
|     "authentik.enterprise.stages.mtls", |     "authentik.enterprise.stages.mtls", | ||||||
|     "authentik.enterprise.stages.source", |     "authentik.enterprise.stages.source", | ||||||
|  | |||||||
| @ -97,7 +97,6 @@ class SourceStageFinal(StageView): | |||||||
|         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) |         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) | ||||||
|         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) |         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) | ||||||
|         plan = token.plan |         plan = token.plan | ||||||
|         plan.context.update(self.executor.plan.context) |  | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = token |         plan.context[PLAN_CONTEXT_IS_RESTORED] = token | ||||||
|         response = plan.to_redirect(self.request, token.flow) |         response = plan.to_redirect(self.request, token.flow) | ||||||
|         token.delete() |         token.delete() | ||||||
|  | |||||||
| @ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase): | |||||||
|         plan: FlowPlan = session[SESSION_KEY_PLAN] |         plan: FlowPlan = session[SESSION_KEY_PLAN] | ||||||
|         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) |         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) | ||||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token |         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token | ||||||
|         plan.context["foo"] = "bar" |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|         session.save() |         session.save() | ||||||
|  |  | ||||||
|         # Pretend we've just returned from the source |         # Pretend we've just returned from the source | ||||||
|         with self.assertFlowFinishes() as ff: |         response = self.client.get( | ||||||
|             response = self.client.get( |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True |         ) | ||||||
|             ) |         self.assertEqual(response.status_code, 200) | ||||||
|             self.assertEqual(response.status_code, 200) |         self.assertStageRedirects( | ||||||
|             self.assertStageRedirects( |             response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||||
|                 response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) |         ) | ||||||
|             ) |  | ||||||
|         self.assertEqual(ff().context["foo"], "bar") |  | ||||||
|  | |||||||
| @ -132,22 +132,6 @@ class EventViewSet(ModelViewSet): | |||||||
|     ] |     ] | ||||||
|     filterset_class = EventsFilter |     filterset_class = EventsFilter | ||||||
|  |  | ||||||
|     def get_ql_fields(self): |  | ||||||
|         from djangoql.schema import DateTimeField, StrField |  | ||||||
|  |  | ||||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField |  | ||||||
|  |  | ||||||
|         return [ |  | ||||||
|             ChoiceSearchField(Event, "action"), |  | ||||||
|             StrField(Event, "event_uuid"), |  | ||||||
|             StrField(Event, "app", suggest_options=True), |  | ||||||
|             StrField(Event, "client_ip"), |  | ||||||
|             JSONSearchField(Event, "user", suggest_nested=False), |  | ||||||
|             JSONSearchField(Event, "brand", suggest_nested=False), |  | ||||||
|             JSONSearchField(Event, "context", suggest_nested=False), |  | ||||||
|             DateTimeField(Event, "created", suggest_options=True), |  | ||||||
|         ] |  | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         methods=["GET"], |         methods=["GET"], | ||||||
|         responses={200: EventTopPerUserSerializer(many=True)}, |         responses={200: EventTopPerUserSerializer(many=True)}, | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ from authentik.events.models import NotificationRule | |||||||
| class NotificationRuleSerializer(ModelSerializer): | class NotificationRuleSerializer(ModelSerializer): | ||||||
|     """NotificationRule Serializer""" |     """NotificationRule Serializer""" | ||||||
|  |  | ||||||
|     destination_group_obj = GroupSerializer(read_only=True, source="destination_group") |     group_obj = GroupSerializer(read_only=True, source="group") | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = NotificationRule |         model = NotificationRule | ||||||
| @ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer): | |||||||
|             "name", |             "name", | ||||||
|             "transports", |             "transports", | ||||||
|             "severity", |             "severity", | ||||||
|             "destination_group", |             "group", | ||||||
|             "destination_group_obj", |             "group_obj", | ||||||
|             "destination_event_user", |  | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet): | |||||||
|  |  | ||||||
|     queryset = NotificationRule.objects.all() |     queryset = NotificationRule.objects.all() | ||||||
|     serializer_class = NotificationRuleSerializer |     serializer_class = NotificationRuleSerializer | ||||||
|     filterset_fields = ["name", "severity", "destination_group__name"] |     filterset_fields = ["name", "severity", "group__name"] | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|     search_fields = ["name", "destination_group__name"] |     search_fields = ["name", "group__name"] | ||||||
|  | |||||||
| @ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor): | |||||||
|         self.reader: Reader | None = None |         self.reader: Reader | None = None | ||||||
|         self._last_mtime: float = 0.0 |         self._last_mtime: float = 0.0 | ||||||
|         self.logger = get_logger() |         self.logger = get_logger() | ||||||
|         self.load() |         self.open() | ||||||
|  |  | ||||||
|     def path(self) -> str | None: |     def path(self) -> str | None: | ||||||
|         """Get the path to the MMDB file to load""" |         """Get the path to the MMDB file to load""" | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def load(self): |     def open(self): | ||||||
|         """Get GeoIP Reader, if configured, otherwise none""" |         """Get GeoIP Reader, if configured, otherwise none""" | ||||||
|         path = self.path() |         path = self.path() | ||||||
|         if path == "" or not path: |         if path == "" or not path: | ||||||
| @ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor): | |||||||
|             diff = self._last_mtime < mtime |             diff = self._last_mtime < mtime | ||||||
|             if diff > 0: |             if diff > 0: | ||||||
|                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) |                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) | ||||||
|                 self.load() |                 self.open() | ||||||
|         except OSError as exc: |         except OSError as exc: | ||||||
|             self.logger.warning("Failed to check MMDB age", exc=exc) |             self.logger.warning("Failed to check MMDB age", exc=exc) | ||||||
|  |  | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models | |||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.events.models import Event, EventAction, Notification | from authentik.events.models import Event, EventAction, Notification | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
| from authentik.lib.sentry import should_ignore_exception | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.stages.authenticator_static.models import StaticToken | from authentik.stages.authenticator_static.models import StaticToken | ||||||
|  |  | ||||||
| @ -173,7 +173,7 @@ class AuditMiddleware: | |||||||
|                 message=exception_to_string(exception), |                 message=exception_to_string(exception), | ||||||
|             ) |             ) | ||||||
|             thread.run() |             thread.run() | ||||||
|         elif not should_ignore_exception(exception): |         elif before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||||
|             thread = EventNewThread( |             thread = EventNewThread( | ||||||
|                 EventAction.SYSTEM_EXCEPTION, |                 EventAction.SYSTEM_EXCEPTION, | ||||||
|                 request, |                 request, | ||||||
|  | |||||||
| @ -1,26 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-16 23:21 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.RenameField( |  | ||||||
|             model_name="notificationrule", |  | ||||||
|             old_name="group", |  | ||||||
|             new_name="destination_group", |  | ||||||
|         ), |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="notificationrule", |  | ||||||
|             name="destination_event_user", |  | ||||||
|             field=models.BooleanField( |  | ||||||
|                 default=False, |  | ||||||
|                 help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.", |  | ||||||
|             ), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -1,12 +1,10 @@ | |||||||
| """authentik events models""" | """authentik events models""" | ||||||
|  |  | ||||||
| from collections.abc import Generator |  | ||||||
| from datetime import timedelta | from datetime import timedelta | ||||||
| from difflib import get_close_matches | from difflib import get_close_matches | ||||||
| from functools import lru_cache | from functools import lru_cache | ||||||
| from inspect import currentframe | from inspect import currentframe | ||||||
| from smtplib import SMTPException | from smtplib import SMTPException | ||||||
| from typing import Any |  | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| @ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel): | |||||||
|             brand: Brand = request.brand |             brand: Brand = request.brand | ||||||
|             self.brand = sanitize_dict(model_to_dict(brand)) |             self.brand = sanitize_dict(model_to_dict(brand)) | ||||||
|         if hasattr(request, "user"): |         if hasattr(request, "user"): | ||||||
|             self.user = get_user(request.user) |             original_user = None | ||||||
|  |             if hasattr(request, "session"): | ||||||
|  |                 original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None) | ||||||
|  |             self.user = get_user(request.user, original_user) | ||||||
|         if user: |         if user: | ||||||
|             self.user = get_user(user) |             self.user = get_user(user) | ||||||
|  |         # Check if we're currently impersonating, and add that user | ||||||
|         if hasattr(request, "session"): |         if hasattr(request, "session"): | ||||||
|             from authentik.flows.views.executor import SESSION_KEY_PLAN |  | ||||||
|  |  | ||||||
|             # Check if we're currently impersonating, and add that user |  | ||||||
|             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: |             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: | ||||||
|                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) |                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) | ||||||
|                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) |                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) | ||||||
|             # Special case for events that happen during a flow, the user might not be authenticated |  | ||||||
|             # yet but is a pending user instead |  | ||||||
|             if SESSION_KEY_PLAN in request.session: |  | ||||||
|                 from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan |  | ||||||
|  |  | ||||||
|                 plan: FlowPlan = request.session[SESSION_KEY_PLAN] |  | ||||||
|                 pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None) |  | ||||||
|                 # Only save `authenticated_as` if there's a different pending user in the flow |  | ||||||
|                 # than the user that is authenticated |  | ||||||
|                 if pending_user and ( |  | ||||||
|                     (pending_user.pk and pending_user.pk != self.user.get("pk")) |  | ||||||
|                     or (not pending_user.pk) |  | ||||||
|                 ): |  | ||||||
|                     orig_user = self.user.copy() |  | ||||||
|  |  | ||||||
|                     self.user = {"authenticated_as": orig_user, **get_user(pending_user)} |  | ||||||
|         # User 255.255.255.255 as fallback if IP cannot be determined |         # User 255.255.255.255 as fallback if IP cannot be determined | ||||||
|         self.client_ip = ClientIPMiddleware.get_client_ip(request) |         self.client_ip = ClientIPMiddleware.get_client_ip(request) | ||||||
|         # Enrich event data |         # Enrich event data | ||||||
| @ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | |||||||
|         default=NotificationSeverity.NOTICE, |         default=NotificationSeverity.NOTICE, | ||||||
|         help_text=_("Controls which severity level the created notifications will have."), |         help_text=_("Controls which severity level the created notifications will have."), | ||||||
|     ) |     ) | ||||||
|     destination_group = models.ForeignKey( |     group = models.ForeignKey( | ||||||
|         Group, |         Group, | ||||||
|         help_text=_( |         help_text=_( | ||||||
|             "Define which group of users this notification should be sent and shown to. " |             "Define which group of users this notification should be sent and shown to. " | ||||||
| @ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | |||||||
|         blank=True, |         blank=True, | ||||||
|         on_delete=models.SET_NULL, |         on_delete=models.SET_NULL, | ||||||
|     ) |     ) | ||||||
|     destination_event_user = models.BooleanField( |  | ||||||
|         default=False, |  | ||||||
|         help_text=_( |  | ||||||
|             "When enabled, notification will be sent to user the user that triggered the event." |  | ||||||
|             "When destination_group is configured, notification is sent to both." |  | ||||||
|         ), |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     def destination_users(self, event: Event) -> Generator[User, Any]: |  | ||||||
|         if self.destination_event_user and event.user.get("pk"): |  | ||||||
|             yield User(pk=event.user.get("pk")) |  | ||||||
|         if self.destination_group: |  | ||||||
|             yield from self.destination_group.users.all() |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[Serializer]: |     def serializer(self) -> type[Serializer]: | ||||||
|  | |||||||
| @ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | |||||||
|     if not result.passing: |     if not result.passing: | ||||||
|         return |         return | ||||||
|  |  | ||||||
|  |     if not trigger.group: | ||||||
|  |         LOGGER.debug("e(trigger): trigger has no group", trigger=trigger) | ||||||
|  |         return | ||||||
|  |  | ||||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) |     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||||
|     # Create the notification objects |     # Create the notification objects | ||||||
|     for transport in trigger.transports.all(): |     for transport in trigger.transports.all(): | ||||||
|         for user in trigger.destination_users(event): |         for user in trigger.group.users.all(): | ||||||
|             LOGGER.debug("created notification") |             LOGGER.debug("created notification") | ||||||
|             notification_transport.apply_async( |             notification_transport.apply_async( | ||||||
|                 args=[ |                 args=[ | ||||||
|  | |||||||
| @ -2,9 +2,7 @@ | |||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.events.context_processors.base import get_context_processors |  | ||||||
| from authentik.events.context_processors.geoip import GeoIPContextProcessor | from authentik.events.context_processors.geoip import GeoIPContextProcessor | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestGeoIP(TestCase): | class TestGeoIP(TestCase): | ||||||
| @ -15,7 +13,8 @@ class TestGeoIP(TestCase): | |||||||
|  |  | ||||||
|     def test_simple(self): |     def test_simple(self): | ||||||
|         """Test simple city wrapper""" |         """Test simple city wrapper""" | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json |         # IPs from | ||||||
|  |         # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             self.reader.city_dict("2.125.160.216"), |             self.reader.city_dict("2.125.160.216"), | ||||||
|             { |             { | ||||||
| @ -26,12 +25,3 @@ class TestGeoIP(TestCase): | |||||||
|                 "long": -1.25, |                 "long": -1.25, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_special_chars(self): |  | ||||||
|         """Test city name with special characters""" |  | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json |  | ||||||
|         event = Event.new(EventAction.LOGIN) |  | ||||||
|         event.client_ip = "89.160.20.112" |  | ||||||
|         for processor in get_context_processors(): |  | ||||||
|             processor.enrich_event(event) |  | ||||||
|         event.save() |  | ||||||
|  | |||||||
| @ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter | |||||||
| from guardian.shortcuts import get_anonymous_user | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import Event | from authentik.events.models import Event | ||||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | from authentik.flows.views.executor import QS_QUERY | ||||||
| from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN |  | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
|  |  | ||||||
| @ -118,92 +116,3 @@ class TestEvents(TestCase): | |||||||
|                 "pk": brand.pk.hex, |                 "pk": brand.pk.hex, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = user |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user_anon(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         anon = get_anonymous_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = anon |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "authenticated_as": { |  | ||||||
|                     "pk": anon.pk, |  | ||||||
|                     "is_anonymous": True, |  | ||||||
|                     "username": "AnonymousUser", |  | ||||||
|                     "email": "", |  | ||||||
|                 }, |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_from_http_flow_pending_user_fake(self): |  | ||||||
|         """Test request from flow request with a pending user""" |  | ||||||
|         user = User( |  | ||||||
|             username=generate_id(), |  | ||||||
|             email=generate_id(), |  | ||||||
|         ) |  | ||||||
|         anon = get_anonymous_user() |  | ||||||
|  |  | ||||||
|         session = self.client.session |  | ||||||
|         plan = FlowPlan(generate_id()) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         request = self.factory.get("/") |  | ||||||
|         request.session = session |  | ||||||
|         request.user = anon |  | ||||||
|  |  | ||||||
|         event = Event.new("unittest").from_http(request) |  | ||||||
|         self.assertEqual( |  | ||||||
|             event.user, |  | ||||||
|             { |  | ||||||
|                 "authenticated_as": { |  | ||||||
|                     "pk": anon.pk, |  | ||||||
|                     "is_anonymous": True, |  | ||||||
|                     "username": "AnonymousUser", |  | ||||||
|                     "email": "", |  | ||||||
|                 }, |  | ||||||
|                 "email": user.email, |  | ||||||
|                 "pk": user.pk, |  | ||||||
|                 "username": user.username, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  | |||||||
| @ -6,7 +6,6 @@ from django.urls import reverse | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import ( | from authentik.events.models import ( | ||||||
|     Event, |     Event, | ||||||
|     EventAction, |     EventAction, | ||||||
| @ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|     def test_trigger_empty(self): |     def test_trigger_empty(self): | ||||||
|         """Test trigger without any policies attached""" |         """Test trigger without any policies attached""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|  |  | ||||||
| @ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|     def test_trigger_single(self): |     def test_trigger_single(self): | ||||||
|         """Test simple transport triggering""" |         """Test simple transport triggering""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase): | |||||||
|             Event.new(EventAction.CUSTOM_PREFIX).save() |             Event.new(EventAction.CUSTOM_PREFIX).save() | ||||||
|         self.assertEqual(execute_mock.call_count, 1) |         self.assertEqual(execute_mock.call_count, 1) | ||||||
|  |  | ||||||
|     def test_trigger_event_user(self): |  | ||||||
|         """Test trigger with event user""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |  | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True) |  | ||||||
|         trigger.transports.add(transport) |  | ||||||
|         trigger.save() |  | ||||||
|         matcher = EventMatcherPolicy.objects.create( |  | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |  | ||||||
|         ) |  | ||||||
|         PolicyBinding.objects.create(target=trigger, policy=matcher, order=0) |  | ||||||
|  |  | ||||||
|         execute_mock = MagicMock() |  | ||||||
|         with patch("authentik.events.models.NotificationTransport.send", execute_mock): |  | ||||||
|             Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save() |  | ||||||
|         self.assertEqual(execute_mock.call_count, 1) |  | ||||||
|         notification: Notification = execute_mock.call_args[0][0] |  | ||||||
|         self.assertEqual(notification.user, user) |  | ||||||
|  |  | ||||||
|     def test_trigger_no_group(self): |     def test_trigger_no_group(self): | ||||||
|         """Test trigger without group""" |         """Test trigger without group""" | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id()) |         trigger = NotificationRule.objects.create(name=generate_id()) | ||||||
| @ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|         """Test Policy error which would cause recursion""" |         """Test Policy error which would cause recursion""" | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id()) |         transport = NotificationTransport.objects.create(name=generate_id()) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|  |  | ||||||
|         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) |         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         trigger.save() |         trigger.save() | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
| @ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase): | |||||||
|             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL |             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL | ||||||
|         ) |         ) | ||||||
|         NotificationRule.objects.filter(name__startswith="default").delete() |         NotificationRule.objects.filter(name__startswith="default").delete() | ||||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) |         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||||
|         trigger.transports.add(transport) |         trigger.transports.add(transport) | ||||||
|         matcher = EventMatcherPolicy.objects.create( |         matcher = EventMatcherPolicy.objects.create( | ||||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX |             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||||
|  | |||||||
| @ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]: | |||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_user(user: User | AnonymousUser) -> dict[str, Any]: | def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||||
|     """Convert user object to dictionary""" |     """Convert user object to dictionary, optionally including the original user""" | ||||||
|     if isinstance(user, AnonymousUser): |     if isinstance(user, AnonymousUser): | ||||||
|         try: |         try: | ||||||
|             user = get_anonymous_user() |             user = get_anonymous_user() | ||||||
| @ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]: | |||||||
|     } |     } | ||||||
|     if user.username == settings.ANONYMOUS_USER_NAME: |     if user.username == settings.ANONYMOUS_USER_NAME: | ||||||
|         user_data["is_anonymous"] = True |         user_data["is_anonymous"] = True | ||||||
|  |     if original_user: | ||||||
|  |         original_data = get_user(original_user) | ||||||
|  |         original_data["on_behalf_of"] = user_data | ||||||
|  |         return original_data | ||||||
|     return user_data |     return user_data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch | |||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.test import override_settings |  | ||||||
| from django.test.client import RequestFactory | from django.test.client import RequestFactory | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.exceptions import ParseError |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_flow, create_test_user | from authentik.core.tests.utils import create_test_flow, create_test_user | ||||||
| @ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase): | |||||||
|             self.assertStageResponse(response, flow, component="ak-stage-identification") |             self.assertStageResponse(response, flow, component="ak-stage-identification") | ||||||
|             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) |             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) | ||||||
|             self.assertStageResponse(response, flow, component="ak-stage-access-denied") |             self.assertStageResponse(response, flow, component="ak-stage-access-denied") | ||||||
|  |  | ||||||
|     @patch( |  | ||||||
|         "authentik.flows.views.executor.to_stage_response", |  | ||||||
|         TO_STAGE_RESPONSE_MOCK, |  | ||||||
|     ) |  | ||||||
|     def test_invalid_json(self): |  | ||||||
|         """Test invalid JSON body""" |  | ||||||
|         flow = create_test_flow() |  | ||||||
|         FlowStageBinding.objects.create( |  | ||||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 |  | ||||||
|         ) |  | ||||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) |  | ||||||
|  |  | ||||||
|         with override_settings(TEST=False, DEBUG=False): |  | ||||||
|             self.client.logout() |  | ||||||
|             response = self.client.post(url, data="{", content_type="application/json") |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|         with self.assertRaises(ParseError): |  | ||||||
|             self.client.logout() |  | ||||||
|             response = self.client.post(url, data="{", content_type="application/json") |  | ||||||
|             self.assertEqual(response.status_code, 200) |  | ||||||
|  | |||||||
| @ -55,7 +55,7 @@ from authentik.flows.planner import ( | |||||||
|     FlowPlanner, |     FlowPlanner, | ||||||
| ) | ) | ||||||
| from authentik.flows.stage import AccessDeniedStage, StageView | from authentik.flows.stage import AccessDeniedStage, StageView | ||||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
| from authentik.lib.utils.reflection import all_subclasses, class_to_path | from authentik.lib.utils.reflection import all_subclasses, class_to_path | ||||||
| from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | ||||||
| @ -69,6 +69,7 @@ SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre" | |||||||
| SESSION_KEY_GET = "authentik/flows/get" | SESSION_KEY_GET = "authentik/flows/get" | ||||||
| SESSION_KEY_POST = "authentik/flows/post" | SESSION_KEY_POST = "authentik/flows/post" | ||||||
| SESSION_KEY_HISTORY = "authentik/flows/history" | SESSION_KEY_HISTORY = "authentik/flows/history" | ||||||
|  | SESSION_KEY_AUTH_STARTED = "authentik/flows/auth_started" | ||||||
| QS_KEY_TOKEN = "flow_token"  # nosec | QS_KEY_TOKEN = "flow_token"  # nosec | ||||||
| QS_QUERY = "query" | QS_QUERY = "query" | ||||||
|  |  | ||||||
| @ -234,13 +235,12 @@ class FlowExecutorView(APIView): | |||||||
|         """Handle exception in stage execution""" |         """Handle exception in stage execution""" | ||||||
|         if settings.DEBUG or settings.TEST: |         if settings.DEBUG or settings.TEST: | ||||||
|             raise exc |             raise exc | ||||||
|  |         capture_exception(exc) | ||||||
|         self._logger.warning(exc) |         self._logger.warning(exc) | ||||||
|         if not should_ignore_exception(exc): |         Event.new( | ||||||
|             capture_exception(exc) |             action=EventAction.SYSTEM_EXCEPTION, | ||||||
|             Event.new( |             message=exception_to_string(exc), | ||||||
|                 action=EventAction.SYSTEM_EXCEPTION, |         ).from_http(self.request) | ||||||
|                 message=exception_to_string(exc), |  | ||||||
|             ).from_http(self.request) |  | ||||||
|         challenge = FlowErrorChallenge(self.request, exc) |         challenge = FlowErrorChallenge(self.request, exc) | ||||||
|         challenge.is_valid(raise_exception=True) |         challenge.is_valid(raise_exception=True) | ||||||
|         return to_stage_response(self.request, HttpChallengeResponse(challenge)) |         return to_stage_response(self.request, HttpChallengeResponse(challenge)) | ||||||
| @ -455,6 +455,7 @@ class FlowExecutorView(APIView): | |||||||
|             SESSION_KEY_APPLICATION_PRE, |             SESSION_KEY_APPLICATION_PRE, | ||||||
|             SESSION_KEY_PLAN, |             SESSION_KEY_PLAN, | ||||||
|             SESSION_KEY_GET, |             SESSION_KEY_GET, | ||||||
|  |             SESSION_KEY_AUTH_STARTED, | ||||||
|             # We might need the initial POST payloads for later requests |             # We might need the initial POST payloads for later requests | ||||||
|             # SESSION_KEY_POST, |             # SESSION_KEY_POST, | ||||||
|             # We don't delete the history on purpose, as a user might |             # We don't delete the history on purpose, as a user might | ||||||
|  | |||||||
| @ -6,7 +6,8 @@ from django.shortcuts import get_object_or_404 | |||||||
| from ua_parser.user_agent_parser import Parse | from ua_parser.user_agent_parser import Parse | ||||||
|  |  | ||||||
| from authentik.core.views.interface import InterfaceView | from authentik.core.views.interface import InterfaceView | ||||||
| from authentik.flows.models import Flow | from authentik.flows.models import Flow, FlowDesignation | ||||||
|  | from authentik.flows.views.executor import SESSION_KEY_AUTH_STARTED | ||||||
|  |  | ||||||
|  |  | ||||||
| class FlowInterfaceView(InterfaceView): | class FlowInterfaceView(InterfaceView): | ||||||
| @ -14,6 +15,12 @@ class FlowInterfaceView(InterfaceView): | |||||||
|  |  | ||||||
|     def get_context_data(self, **kwargs: Any) -> dict[str, Any]: |     def get_context_data(self, **kwargs: Any) -> dict[str, Any]: | ||||||
|         flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug")) |         flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug")) | ||||||
|  |         if ( | ||||||
|  |             not self.request.user.is_authenticated | ||||||
|  |             and flow.designation == FlowDesignation.AUTHENTICATION | ||||||
|  |         ): | ||||||
|  |             self.request.session[SESSION_KEY_AUTH_STARTED] = True | ||||||
|  |             self.request.session.save() | ||||||
|         kwargs["flow"] = flow |         kwargs["flow"] = flow | ||||||
|         kwargs["flow_background_url"] = flow.background_url(self.request) |         kwargs["flow_background_url"] = flow.background_url(self.request) | ||||||
|         kwargs["inspector"] = "inspector" in self.request.GET |         kwargs["inspector"] = "inspector" in self.request.GET | ||||||
|  | |||||||
| @ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted | |||||||
| from docker.errors import DockerException | from docker.errors import DockerException | ||||||
| from h11 import LocalProtocolError | from h11 import LocalProtocolError | ||||||
| from ldap3.core.exceptions import LDAPException | from ldap3.core.exceptions import LDAPException | ||||||
| from psycopg.errors import Error |  | ||||||
| from redis.exceptions import ConnectionError as RedisConnectionError | from redis.exceptions import ConnectionError as RedisConnectionError | ||||||
| from redis.exceptions import RedisError, ResponseError | from redis.exceptions import RedisError, ResponseError | ||||||
| from rest_framework.exceptions import APIException | from rest_framework.exceptions import APIException | ||||||
| @ -45,49 +44,6 @@ class SentryIgnoredException(Exception): | |||||||
|     """Base Class for all errors that are suppressed, and not sent to sentry.""" |     """Base Class for all errors that are suppressed, and not sent to sentry.""" | ||||||
|  |  | ||||||
|  |  | ||||||
| ignored_classes = ( |  | ||||||
|     # Inbuilt types |  | ||||||
|     KeyboardInterrupt, |  | ||||||
|     ConnectionResetError, |  | ||||||
|     OSError, |  | ||||||
|     PermissionError, |  | ||||||
|     # Django Errors |  | ||||||
|     Error, |  | ||||||
|     ImproperlyConfigured, |  | ||||||
|     DatabaseError, |  | ||||||
|     OperationalError, |  | ||||||
|     InternalError, |  | ||||||
|     ProgrammingError, |  | ||||||
|     SuspiciousOperation, |  | ||||||
|     ValidationError, |  | ||||||
|     # Redis errors |  | ||||||
|     RedisConnectionError, |  | ||||||
|     ConnectionInterrupted, |  | ||||||
|     RedisError, |  | ||||||
|     ResponseError, |  | ||||||
|     # websocket errors |  | ||||||
|     ChannelFull, |  | ||||||
|     WebSocketException, |  | ||||||
|     LocalProtocolError, |  | ||||||
|     # rest_framework error |  | ||||||
|     APIException, |  | ||||||
|     # celery errors |  | ||||||
|     WorkerLostError, |  | ||||||
|     CeleryError, |  | ||||||
|     SoftTimeLimitExceeded, |  | ||||||
|     # custom baseclass |  | ||||||
|     SentryIgnoredException, |  | ||||||
|     # ldap errors |  | ||||||
|     LDAPException, |  | ||||||
|     # Docker errors |  | ||||||
|     DockerException, |  | ||||||
|     # End-user errors |  | ||||||
|     Http404, |  | ||||||
|     # AsyncIO |  | ||||||
|     CancelledError, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SentryTransport(HttpTransport): | class SentryTransport(HttpTransport): | ||||||
|     """Custom sentry transport with custom user-agent""" |     """Custom sentry transport with custom user-agent""" | ||||||
|  |  | ||||||
| @ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float: | |||||||
|     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) |     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def should_ignore_exception(exc: Exception) -> bool: |  | ||||||
|     """Check if an exception should be dropped""" |  | ||||||
|     return isinstance(exc, ignored_classes) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def before_send(event: dict, hint: dict) -> dict | None: | def before_send(event: dict, hint: dict) -> dict | None: | ||||||
|     """Check if error is database error, and ignore if so""" |     """Check if error is database error, and ignore if so""" | ||||||
|  |  | ||||||
|  |     from psycopg.errors import Error | ||||||
|  |  | ||||||
|  |     ignored_classes = ( | ||||||
|  |         # Inbuilt types | ||||||
|  |         KeyboardInterrupt, | ||||||
|  |         ConnectionResetError, | ||||||
|  |         OSError, | ||||||
|  |         PermissionError, | ||||||
|  |         # Django Errors | ||||||
|  |         Error, | ||||||
|  |         ImproperlyConfigured, | ||||||
|  |         DatabaseError, | ||||||
|  |         OperationalError, | ||||||
|  |         InternalError, | ||||||
|  |         ProgrammingError, | ||||||
|  |         SuspiciousOperation, | ||||||
|  |         ValidationError, | ||||||
|  |         # Redis errors | ||||||
|  |         RedisConnectionError, | ||||||
|  |         ConnectionInterrupted, | ||||||
|  |         RedisError, | ||||||
|  |         ResponseError, | ||||||
|  |         # websocket errors | ||||||
|  |         ChannelFull, | ||||||
|  |         WebSocketException, | ||||||
|  |         LocalProtocolError, | ||||||
|  |         # rest_framework error | ||||||
|  |         APIException, | ||||||
|  |         # celery errors | ||||||
|  |         WorkerLostError, | ||||||
|  |         CeleryError, | ||||||
|  |         SoftTimeLimitExceeded, | ||||||
|  |         # custom baseclass | ||||||
|  |         SentryIgnoredException, | ||||||
|  |         # ldap errors | ||||||
|  |         LDAPException, | ||||||
|  |         # Docker errors | ||||||
|  |         DockerException, | ||||||
|  |         # End-user errors | ||||||
|  |         Http404, | ||||||
|  |         # AsyncIO | ||||||
|  |         CancelledError, | ||||||
|  |     ) | ||||||
|     exc_value = None |     exc_value = None | ||||||
|     if "exc_info" in hint: |     if "exc_info" in hint: | ||||||
|         _, exc_value, _ = hint["exc_info"] |         _, exc_value, _ = hint["exc_info"] | ||||||
|         if should_ignore_exception(exc_value): |         if isinstance(exc_value, ignored_classes): | ||||||
|             LOGGER.debug("dropping exception", exc=exc_value) |             LOGGER.debug("dropping exception", exc=exc_value) | ||||||
|             return None |             return None | ||||||
|     if "logger" in event: |     if "logger" in event: | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  |  | ||||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | from authentik.lib.sentry import SentryIgnoredException, before_send | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSentry(TestCase): | class TestSentry(TestCase): | ||||||
| @ -10,8 +10,8 @@ class TestSentry(TestCase): | |||||||
|  |  | ||||||
|     def test_error_not_sent(self): |     def test_error_not_sent(self): | ||||||
|         """Test SentryIgnoredError not sent""" |         """Test SentryIgnoredError not sent""" | ||||||
|         self.assertTrue(should_ignore_exception(SentryIgnoredException())) |         self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) | ||||||
|  |  | ||||||
|     def test_error_sent(self): |     def test_error_sent(self): | ||||||
|         """Test error sent""" |         """Test error sent""" | ||||||
|         self.assertFalse(should_ignore_exception(ValueError())) |         self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)})) | ||||||
|  | |||||||
| @ -1,13 +1,15 @@ | |||||||
| """authentik outpost signals""" | """authentik outpost signals""" | ||||||
|  |  | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models import Model | from django.db.models import Model | ||||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import AuthenticatedSession, Provider | from authentik.core.models import AuthenticatedSession, Provider, User | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||||
| @ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): | |||||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) |     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||||
|  |     """Catch logout by direct logout and forward to providers""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     outpost_session_end.delay(request.session.session_key) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||||
|     """Catch logout by expiring sessions being deleted""" |     """Catch logout by expiring sessions being deleted""" | ||||||
|  | |||||||
| @ -39,3 +39,4 @@ class AuthentikPoliciesConfig(ManagedAppConfig): | |||||||
|     label = "authentik_policies" |     label = "authentik_policies" | ||||||
|     verbose_name = "authentik Policies" |     verbose_name = "authentik Policies" | ||||||
|     default = True |     default = True | ||||||
|  |     mountpoint = "policy/" | ||||||
|  | |||||||
							
								
								
									
										89
									
								
								authentik/policies/templates/policies/buffer.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								authentik/policies/templates/policies/buffer.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,89 @@ | |||||||
|  | {% extends 'login/base_full.html' %} | ||||||
|  |  | ||||||
|  | {% load static %} | ||||||
|  | {% load i18n %} | ||||||
|  |  | ||||||
|  | {% block head %} | ||||||
|  | {{ block.super }} | ||||||
|  | <script> | ||||||
|  |   let redirecting = false; | ||||||
|  |   const checkAuth = async () => { | ||||||
|  |     if (redirecting) return true; | ||||||
|  |     const url = "{{ check_auth_url }}"; | ||||||
|  |     console.debug("authentik/policies/buffer: Checking authentication..."); | ||||||
|  |     try { | ||||||
|  |       const result = await fetch(url, { | ||||||
|  |         method: "HEAD", | ||||||
|  |       }); | ||||||
|  |       if (result.status >= 400) { | ||||||
|  |         return false | ||||||
|  |       } | ||||||
|  |       console.debug("authentik/policies/buffer: Continuing"); | ||||||
|  |       redirecting = true; | ||||||
|  |       if ("{{ auth_req_method }}" === "post") { | ||||||
|  |         document.querySelector("form").submit(); | ||||||
|  |       } else { | ||||||
|  |         window.location.assign("{{ continue_url|escapejs }}"); | ||||||
|  |       } | ||||||
|  |     } catch { | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |   let timeout = 100; | ||||||
|  |   let offset = 20; | ||||||
|  |   let attempt = 0; | ||||||
|  |   const main = async () => { | ||||||
|  |     attempt += 1; | ||||||
|  |     await checkAuth(); | ||||||
|  |     console.debug(`authentik/policies/buffer: Waiting ${timeout}ms...`); | ||||||
|  |     setTimeout(main, timeout); | ||||||
|  |     timeout += (offset * attempt); | ||||||
|  |     if (timeout >= 2000) { | ||||||
|  |       timeout = 2000; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   document.addEventListener("visibilitychange", async () => { | ||||||
|  |     if (document.hidden) return; | ||||||
|  |     console.debug("authentik/policies/buffer: Checking authentication on tab activate..."); | ||||||
|  |     await checkAuth(); | ||||||
|  |   }); | ||||||
|  |   main(); | ||||||
|  | </script> | ||||||
|  | {% endblock %} | ||||||
|  |  | ||||||
|  | {% block title %} | ||||||
|  | {% trans 'Waiting for authentication...' %} - {{ brand.branding_title }} | ||||||
|  | {% endblock %} | ||||||
|  |  | ||||||
|  | {% block card_title %} | ||||||
|  | {% trans 'Waiting for authentication...' %} | ||||||
|  | {% endblock %} | ||||||
|  |  | ||||||
|  | {% block card %} | ||||||
|  | <form class="pf-c-form" method="{{ auth_req_method }}" action="{{ continue_url }}"> | ||||||
|  |   {% if auth_req_method == "post" %} | ||||||
|  |     {% for key, value in auth_req_body.items %} | ||||||
|  |       <input type="hidden" name="{{ key }}" value="{{ value }}" /> | ||||||
|  |     {% endfor %} | ||||||
|  |   {% endif %} | ||||||
|  |   <div class="pf-c-empty-state"> | ||||||
|  |     <div class="pf-c-empty-state__content"> | ||||||
|  |       <div class="pf-c-empty-state__icon"> | ||||||
|  |         <span class="pf-c-spinner pf-m-xl" role="progressbar"> | ||||||
|  |           <span class="pf-c-spinner__clipper"></span> | ||||||
|  |           <span class="pf-c-spinner__lead-ball"></span> | ||||||
|  |           <span class="pf-c-spinner__tail-ball"></span> | ||||||
|  |         </span> | ||||||
|  |       </div> | ||||||
|  |       <h1 class="pf-c-title pf-m-lg"> | ||||||
|  |         {% trans "You're already authenticating in another tab. This page will refresh once authentication is completed." %} | ||||||
|  |       </h1> | ||||||
|  |     </div> | ||||||
|  |   </div> | ||||||
|  |   <div class="pf-c-form__group pf-m-action"> | ||||||
|  |     <a href="{{ auth_req_url }}" class="pf-c-button pf-m-primary pf-m-block"> | ||||||
|  |       {% trans "Authenticate in this tab" %} | ||||||
|  |     </a> | ||||||
|  |   </div> | ||||||
|  | </form> | ||||||
|  | {% endblock %} | ||||||
							
								
								
									
										121
									
								
								authentik/policies/tests/test_views.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								authentik/policies/tests/test_views.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,121 @@ | |||||||
|  | from django.contrib.auth.models import AnonymousUser | ||||||
|  | from django.contrib.sessions.middleware import SessionMiddleware | ||||||
|  | from django.http import HttpResponse | ||||||
|  | from django.test import RequestFactory, TestCase | ||||||
|  | from django.urls import reverse | ||||||
|  |  | ||||||
|  | from authentik.core.models import Application, Provider | ||||||
|  | from authentik.core.tests.utils import create_test_flow, create_test_user | ||||||
|  | from authentik.flows.models import FlowDesignation | ||||||
|  | from authentik.flows.planner import FlowPlan | ||||||
|  | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.lib.tests.utils import dummy_get_response | ||||||
|  | from authentik.policies.views import ( | ||||||
|  |     QS_BUFFER_ID, | ||||||
|  |     SESSION_KEY_BUFFER, | ||||||
|  |     BufferedPolicyAccessView, | ||||||
|  |     BufferView, | ||||||
|  |     PolicyAccessView, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestPolicyViews(TestCase): | ||||||
|  |     """Test PolicyAccessView""" | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         super().setUp() | ||||||
|  |         self.factory = RequestFactory() | ||||||
|  |         self.user = create_test_user() | ||||||
|  |  | ||||||
|  |     def test_pav(self): | ||||||
|  |         """Test simple policy access view""" | ||||||
|  |         provider = Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |         ) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||||
|  |  | ||||||
|  |         class TestView(PolicyAccessView): | ||||||
|  |             def resolve_provider_application(self): | ||||||
|  |                 self.provider = provider | ||||||
|  |                 self.application = app | ||||||
|  |  | ||||||
|  |             def get(self, *args, **kwargs): | ||||||
|  |                 return HttpResponse("foo") | ||||||
|  |  | ||||||
|  |         req = self.factory.get("/") | ||||||
|  |         req.user = self.user | ||||||
|  |         res = TestView.as_view()(req) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         self.assertEqual(res.content, b"foo") | ||||||
|  |  | ||||||
|  |     def test_pav_buffer(self): | ||||||
|  |         """Test simple policy access view""" | ||||||
|  |         provider = Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |         ) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||||
|  |         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||||
|  |  | ||||||
|  |         class TestView(BufferedPolicyAccessView): | ||||||
|  |             def resolve_provider_application(self): | ||||||
|  |                 self.provider = provider | ||||||
|  |                 self.application = app | ||||||
|  |  | ||||||
|  |             def get(self, *args, **kwargs): | ||||||
|  |                 return HttpResponse("foo") | ||||||
|  |  | ||||||
|  |         req = self.factory.get("/") | ||||||
|  |         req.user = AnonymousUser() | ||||||
|  |         middleware = SessionMiddleware(dummy_get_response) | ||||||
|  |         middleware.process_request(req) | ||||||
|  |         req.session[SESSION_KEY_PLAN] = FlowPlan(flow.pk) | ||||||
|  |         req.session.save() | ||||||
|  |         res = TestView.as_view()(req) | ||||||
|  |         self.assertEqual(res.status_code, 302) | ||||||
|  |         self.assertTrue(res.url.startswith(reverse("authentik_policies:buffer"))) | ||||||
|  |  | ||||||
|  |     def test_pav_buffer_skip(self): | ||||||
|  |         """Test simple policy access view (skip buffer)""" | ||||||
|  |         provider = Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |         ) | ||||||
|  |         app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider) | ||||||
|  |         flow = create_test_flow(FlowDesignation.AUTHENTICATION) | ||||||
|  |  | ||||||
|  |         class TestView(BufferedPolicyAccessView): | ||||||
|  |             def resolve_provider_application(self): | ||||||
|  |                 self.provider = provider | ||||||
|  |                 self.application = app | ||||||
|  |  | ||||||
|  |             def get(self, *args, **kwargs): | ||||||
|  |                 return HttpResponse("foo") | ||||||
|  |  | ||||||
|  |         req = self.factory.get("/?skip_buffer=true") | ||||||
|  |         req.user = AnonymousUser() | ||||||
|  |         middleware = SessionMiddleware(dummy_get_response) | ||||||
|  |         middleware.process_request(req) | ||||||
|  |         req.session[SESSION_KEY_PLAN] = FlowPlan(flow.pk) | ||||||
|  |         req.session.save() | ||||||
|  |         res = TestView.as_view()(req) | ||||||
|  |         self.assertEqual(res.status_code, 302) | ||||||
|  |         self.assertTrue(res.url.startswith(reverse("authentik_flows:default-authentication"))) | ||||||
|  |  | ||||||
|  |     def test_buffer(self): | ||||||
|  |         """Test buffer view""" | ||||||
|  |         uid = generate_id() | ||||||
|  |         req = self.factory.get(f"/?{QS_BUFFER_ID}={uid}") | ||||||
|  |         req.user = AnonymousUser() | ||||||
|  |         middleware = SessionMiddleware(dummy_get_response) | ||||||
|  |         middleware.process_request(req) | ||||||
|  |         ts = generate_id() | ||||||
|  |         req.session[SESSION_KEY_BUFFER % uid] = { | ||||||
|  |             "method": "get", | ||||||
|  |             "body": {}, | ||||||
|  |             "url": f"/{ts}", | ||||||
|  |         } | ||||||
|  |         req.session.save() | ||||||
|  |  | ||||||
|  |         res = BufferView.as_view()(req) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         self.assertIn(ts, res.render().content.decode()) | ||||||
| @ -1,7 +1,14 @@ | |||||||
| """API URLs""" | """API URLs""" | ||||||
|  |  | ||||||
|  | from django.urls import path | ||||||
|  |  | ||||||
| from authentik.policies.api.bindings import PolicyBindingViewSet | from authentik.policies.api.bindings import PolicyBindingViewSet | ||||||
| from authentik.policies.api.policies import PolicyViewSet | from authentik.policies.api.policies import PolicyViewSet | ||||||
|  | from authentik.policies.views import BufferView | ||||||
|  |  | ||||||
|  | urlpatterns = [ | ||||||
|  |     path("buffer", BufferView.as_view(), name="buffer"), | ||||||
|  | ] | ||||||
|  |  | ||||||
| api_urlpatterns = [ | api_urlpatterns = [ | ||||||
|     ("policies/all", PolicyViewSet), |     ("policies/all", PolicyViewSet), | ||||||
|  | |||||||
| @ -1,23 +1,37 @@ | |||||||
| """authentik access helper classes""" | """authentik access helper classes""" | ||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.contrib import messages | from django.contrib import messages | ||||||
| from django.contrib.auth.mixins import AccessMixin | from django.contrib.auth.mixins import AccessMixin | ||||||
| from django.contrib.auth.views import redirect_to_login | from django.contrib.auth.views import redirect_to_login | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse, QueryDict | ||||||
|  | from django.shortcuts import redirect | ||||||
|  | from django.urls import reverse | ||||||
|  | from django.utils.http import urlencode | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from django.views.generic.base import View | from django.views.generic.base import TemplateView, View | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import Application, Provider, User | from authentik.core.models import Application, Provider, User | ||||||
| from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE, SESSION_KEY_POST | from authentik.flows.models import Flow, FlowDesignation | ||||||
|  | from authentik.flows.planner import FlowPlan | ||||||
|  | from authentik.flows.views.executor import ( | ||||||
|  |     SESSION_KEY_APPLICATION_PRE, | ||||||
|  |     SESSION_KEY_AUTH_STARTED, | ||||||
|  |     SESSION_KEY_PLAN, | ||||||
|  |     SESSION_KEY_POST, | ||||||
|  | ) | ||||||
| from authentik.lib.sentry import SentryIgnoredException | from authentik.lib.sentry import SentryIgnoredException | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.types import PolicyRequest, PolicyResult | from authentik.policies.types import PolicyRequest, PolicyResult | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  | QS_BUFFER_ID = "af_bf_id" | ||||||
|  | QS_SKIP_BUFFER = "skip_buffer" | ||||||
|  | SESSION_KEY_BUFFER = "authentik/policies/pav_buffer/%s" | ||||||
|  |  | ||||||
|  |  | ||||||
| class RequestValidationError(SentryIgnoredException): | class RequestValidationError(SentryIgnoredException): | ||||||
| @ -125,3 +139,65 @@ class PolicyAccessView(AccessMixin, View): | |||||||
|             for message in result.messages: |             for message in result.messages: | ||||||
|                 messages.error(self.request, _(message)) |                 messages.error(self.request, _(message)) | ||||||
|         return result |         return result | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def url_with_qs(url: str, **kwargs): | ||||||
|  |     """Update/set querystring of `url` with the parameters in `kwargs`. Original query string | ||||||
|  |     parameters are retained""" | ||||||
|  |     if "?" not in url: | ||||||
|  |         return url + f"?{urlencode(kwargs)}" | ||||||
|  |     url, _, qs = url.partition("?") | ||||||
|  |     qs = QueryDict(qs, mutable=True) | ||||||
|  |     qs.update(kwargs) | ||||||
|  |     return url + f"?{urlencode(qs.items())}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BufferView(TemplateView): | ||||||
|  |     """Buffer view""" | ||||||
|  |  | ||||||
|  |     template_name = "policies/buffer.html" | ||||||
|  |  | ||||||
|  |     def get_context_data(self, **kwargs): | ||||||
|  |         buf_id = self.request.GET.get(QS_BUFFER_ID) | ||||||
|  |         buffer: dict = self.request.session.get(SESSION_KEY_BUFFER % buf_id) | ||||||
|  |         kwargs["auth_req_method"] = buffer["method"] | ||||||
|  |         kwargs["auth_req_body"] = buffer["body"] | ||||||
|  |         kwargs["auth_req_url"] = url_with_qs(buffer["url"], **{QS_SKIP_BUFFER: True}) | ||||||
|  |         kwargs["check_auth_url"] = reverse("authentik_api:user-me") | ||||||
|  |         kwargs["continue_url"] = url_with_qs(buffer["url"], **{QS_BUFFER_ID: buf_id}) | ||||||
|  |         return super().get_context_data(**kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BufferedPolicyAccessView(PolicyAccessView): | ||||||
|  |     """PolicyAccessView which buffers access requests in case the user is not logged in""" | ||||||
|  |  | ||||||
|  |     def handle_no_permission(self): | ||||||
|  |         plan: FlowPlan | None = self.request.session.get(SESSION_KEY_PLAN) | ||||||
|  |         authenticating = self.request.session.get(SESSION_KEY_AUTH_STARTED) | ||||||
|  |         if plan: | ||||||
|  |             flow = Flow.objects.filter(pk=plan.flow_pk).first() | ||||||
|  |             if not flow or flow.designation != FlowDesignation.AUTHENTICATION: | ||||||
|  |                 LOGGER.debug("Not buffering request, no flow or flow not for authentication") | ||||||
|  |                 return super().handle_no_permission() | ||||||
|  |         if not plan and authenticating is None: | ||||||
|  |             LOGGER.debug("Not buffering request, no flow plan active") | ||||||
|  |             return super().handle_no_permission() | ||||||
|  |         if self.request.GET.get(QS_SKIP_BUFFER): | ||||||
|  |             LOGGER.debug("Not buffering request, explicit skip") | ||||||
|  |             return super().handle_no_permission() | ||||||
|  |         buffer_id = str(uuid4()) | ||||||
|  |         LOGGER.debug("Buffering access request", bf_id=buffer_id) | ||||||
|  |         self.request.session[SESSION_KEY_BUFFER % buffer_id] = { | ||||||
|  |             "body": self.request.POST, | ||||||
|  |             "url": self.request.build_absolute_uri(self.request.get_full_path()), | ||||||
|  |             "method": self.request.method.lower(), | ||||||
|  |         } | ||||||
|  |         return redirect( | ||||||
|  |             url_with_qs(reverse("authentik_policies:buffer"), **{QS_BUFFER_ID: buffer_id}) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def dispatch(self, request, *args, **kwargs): | ||||||
|  |         response = super().dispatch(request, *args, **kwargs) | ||||||
|  |         if QS_BUFFER_ID in self.request.GET: | ||||||
|  |             self.request.session.pop(SESSION_KEY_BUFFER % self.request.GET[QS_BUFFER_ID], None) | ||||||
|  |         return response | ||||||
|  | |||||||
| @ -1,10 +1,23 @@ | |||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.db.models.signals import post_save, pre_delete | from django.db.models.signals import post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession, User | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Revoke tokens upon user logout""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     AccessToken.objects.filter( | ||||||
|  |         user=user, | ||||||
|  |         session__session__session_key=request.session.session_key, | ||||||
|  |     ).delete() | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | ||||||
|     """Revoke tokens upon user logout""" |     """Revoke tokens upon user logout""" | ||||||
|  | |||||||
| @ -387,7 +387,8 @@ class TestAuthorize(OAuthTestCase): | |||||||
|             self.assertEqual( |             self.assertEqual( | ||||||
|                 response.url, |                 response.url, | ||||||
|                 ( |                 ( | ||||||
|                     f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}" |                     f"http://localhost#access_token={token.token}" | ||||||
|  |                     f"&id_token={provider.encode(token.id_token.to_dict())}" | ||||||
|                     f"&token_type={TOKEN_TYPE}" |                     f"&token_type={TOKEN_TYPE}" | ||||||
|                     f"&expires_in={int(expires)}&state={state}" |                     f"&expires_in={int(expires)}&state={state}" | ||||||
|                 ), |                 ), | ||||||
| @ -562,6 +563,7 @@ class TestAuthorize(OAuthTestCase): | |||||||
|                 "url": "http://localhost", |                 "url": "http://localhost", | ||||||
|                 "title": f"Redirecting to {app.name}...", |                 "title": f"Redirecting to {app.name}...", | ||||||
|                 "attrs": { |                 "attrs": { | ||||||
|  |                     "access_token": token.token, | ||||||
|                     "id_token": provider.encode(token.id_token.to_dict()), |                     "id_token": provider.encode(token.id_token.to_dict()), | ||||||
|                     "token_type": TOKEN_TYPE, |                     "token_type": TOKEN_TYPE, | ||||||
|                     "expires_in": "3600", |                     "expires_in": "3600", | ||||||
|  | |||||||
| @ -30,7 +30,7 @@ from authentik.flows.stage import StageView | |||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
| from authentik.policies.types import PolicyRequest | from authentik.policies.types import PolicyRequest | ||||||
| from authentik.policies.views import PolicyAccessView, RequestValidationError | from authentik.policies.views import BufferedPolicyAccessView, RequestValidationError | ||||||
| from authentik.providers.oauth2.constants import ( | from authentik.providers.oauth2.constants import ( | ||||||
|     PKCE_METHOD_PLAIN, |     PKCE_METHOD_PLAIN, | ||||||
|     PKCE_METHOD_S256, |     PKCE_METHOD_S256, | ||||||
| @ -150,12 +150,12 @@ class OAuthAuthorizationParams: | |||||||
|         self.check_redirect_uri() |         self.check_redirect_uri() | ||||||
|         self.check_grant() |         self.check_grant() | ||||||
|         self.check_scope(github_compat) |         self.check_scope(github_compat) | ||||||
|  |         self.check_nonce() | ||||||
|  |         self.check_code_challenge() | ||||||
|         if self.request: |         if self.request: | ||||||
|             raise AuthorizeError( |             raise AuthorizeError( | ||||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state |                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||||
|             ) |             ) | ||||||
|         self.check_nonce() |  | ||||||
|         self.check_code_challenge() |  | ||||||
|  |  | ||||||
|     def check_grant(self): |     def check_grant(self): | ||||||
|         """Check grant""" |         """Check grant""" | ||||||
| @ -326,7 +326,7 @@ class OAuthAuthorizationParams: | |||||||
|         return code |         return code | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthorizationFlowInitView(PolicyAccessView): | class AuthorizationFlowInitView(BufferedPolicyAccessView): | ||||||
|     """OAuth2 Flow initializer, checks access to application and starts flow""" |     """OAuth2 Flow initializer, checks access to application and starts flow""" | ||||||
|  |  | ||||||
|     params: OAuthAuthorizationParams |     params: OAuthAuthorizationParams | ||||||
| @ -630,6 +630,7 @@ class OAuthFulfillmentStage(StageView): | |||||||
|         if self.params.response_type in [ |         if self.params.response_type in [ | ||||||
|             ResponseTypes.ID_TOKEN_TOKEN, |             ResponseTypes.ID_TOKEN_TOKEN, | ||||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, |             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||||
|  |             ResponseTypes.ID_TOKEN, | ||||||
|             ResponseTypes.CODE_TOKEN, |             ResponseTypes.CODE_TOKEN, | ||||||
|         ]: |         ]: | ||||||
|             query_fragment["access_token"] = token.token |             query_fragment["access_token"] = token.token | ||||||
|  | |||||||
| @ -2,11 +2,13 @@ | |||||||
|  |  | ||||||
| from asgiref.sync import async_to_sync | from asgiref.sync import async_to_sync | ||||||
| from channels.layers import get_channel_layer | from channels.layers import get_channel_layer | ||||||
|  | from django.contrib.auth.signals import user_logged_out | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db.models.signals import post_delete, post_save, pre_delete | from django.db.models.signals import post_delete, post_save, pre_delete | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  | from django.http import HttpRequest | ||||||
|  |  | ||||||
| from authentik.core.models import AuthenticatedSession | from authentik.core.models import AuthenticatedSession, User | ||||||
| from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | ||||||
| from authentik.providers.rac.consumer_client import ( | from authentik.providers.rac.consumer_client import ( | ||||||
|     RAC_CLIENT_GROUP_SESSION, |     RAC_CLIENT_GROUP_SESSION, | ||||||
| @ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import ( | |||||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint | from authentik.providers.rac.models import ConnectionToken, Endpoint | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @receiver(user_logged_out) | ||||||
|  | def user_logged_out_session(sender, request: HttpRequest, user: User, **_): | ||||||
|  |     """Disconnect any open RAC connections""" | ||||||
|  |     if not request.session or not request.session.session_key: | ||||||
|  |         return | ||||||
|  |     layer = get_channel_layer() | ||||||
|  |     async_to_sync(layer.group_send)( | ||||||
|  |         RAC_CLIENT_GROUP_SESSION | ||||||
|  |         % { | ||||||
|  |             "session": request.session.session_key, | ||||||
|  |         }, | ||||||
|  |         {"type": "event.disconnect", "reason": "session_logout"}, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_delete, sender=AuthenticatedSession) | @receiver(pre_delete, sender=AuthenticatedSession) | ||||||
| def user_session_deleted(sender, instance: AuthenticatedSession, **_): | def user_session_deleted(sender, instance: AuthenticatedSession, **_): | ||||||
|     layer = get_channel_layer() |     layer = get_channel_layer() | ||||||
|  | |||||||
| @ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             response.content.decode(), |             response.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -18,14 +18,11 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner | |||||||
| from authentik.flows.stage import RedirectStage | from authentik.flows.stage import RedirectStage | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.policies.engine import PolicyEngine | from authentik.policies.engine import PolicyEngine | ||||||
| from authentik.policies.views import PolicyAccessView | from authentik.policies.views import BufferedPolicyAccessView | ||||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider | from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider | ||||||
| from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT |  | ||||||
|  |  | ||||||
| PLAN_CONNECTION_SETTINGS = "connection_settings" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class RACStartView(PolicyAccessView): | class RACStartView(BufferedPolicyAccessView): | ||||||
|     """Start a RAC connection by checking access and creating a connection token""" |     """Start a RAC connection by checking access and creating a connection token""" | ||||||
|  |  | ||||||
|     endpoint: Endpoint |     endpoint: Endpoint | ||||||
| @ -112,15 +109,10 @@ class RACFinalStage(RedirectStage): | |||||||
|         return super().dispatch(request, *args, **kwargs) |         return super().dispatch(request, *args, **kwargs) | ||||||
|  |  | ||||||
|     def get_challenge(self, *args, **kwargs) -> RedirectChallenge: |     def get_challenge(self, *args, **kwargs) -> RedirectChallenge: | ||||||
|         settings = self.executor.plan.context.get(PLAN_CONNECTION_SETTINGS) |  | ||||||
|         if not settings: |  | ||||||
|             settings = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).get( |  | ||||||
|                 PLAN_CONNECTION_SETTINGS |  | ||||||
|             ) |  | ||||||
|         token = ConnectionToken.objects.create( |         token = ConnectionToken.objects.create( | ||||||
|             provider=self.provider, |             provider=self.provider, | ||||||
|             endpoint=self.endpoint, |             endpoint=self.endpoint, | ||||||
|             settings=settings or {}, |             settings=self.executor.plan.context.get("connection_settings", {}), | ||||||
|             session=self.request.session["authenticatedsession"], |             session=self.request.session["authenticatedsession"], | ||||||
|             expires=now() + timedelta_from_string(self.provider.connection_expiry), |             expires=now() + timedelta_from_string(self.provider.connection_expiry), | ||||||
|             expiring=True, |             expiring=True, | ||||||
|  | |||||||
| @ -35,8 +35,8 @@ REQUEST_KEY_SAML_SIG_ALG = "SigAlg" | |||||||
| REQUEST_KEY_SAML_RESPONSE = "SAMLResponse" | REQUEST_KEY_SAML_RESPONSE = "SAMLResponse" | ||||||
| REQUEST_KEY_RELAY_STATE = "RelayState" | REQUEST_KEY_RELAY_STATE = "RelayState" | ||||||
|  |  | ||||||
| SESSION_KEY_AUTH_N_REQUEST = "authentik/providers/saml/authn_request" | PLAN_CONTEXT_SAML_AUTH_N_REQUEST = "authentik/providers/saml/authn_request" | ||||||
| SESSION_KEY_LOGOUT_REQUEST = "authentik/providers/saml/logout_request" | PLAN_CONTEXT_SAML_LOGOUT_REQUEST = "authentik/providers/saml/logout_request" | ||||||
|  |  | ||||||
|  |  | ||||||
| # This View doesn't have a URL on purpose, as its called by the FlowExecutor | # This View doesn't have a URL on purpose, as its called by the FlowExecutor | ||||||
| @ -50,10 +50,11 @@ class SAMLFlowFinalView(ChallengeStageView): | |||||||
|     def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: |     def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: | ||||||
|         application: Application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION] |         application: Application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION] | ||||||
|         provider: SAMLProvider = get_object_or_404(SAMLProvider, pk=application.provider_id) |         provider: SAMLProvider = get_object_or_404(SAMLProvider, pk=application.provider_id) | ||||||
|         if SESSION_KEY_AUTH_N_REQUEST not in self.request.session: |         if PLAN_CONTEXT_SAML_AUTH_N_REQUEST not in self.executor.plan.context: | ||||||
|  |             self.logger.warning("No AuthNRequest in context") | ||||||
|             return self.executor.stage_invalid() |             return self.executor.stage_invalid() | ||||||
|  |  | ||||||
|         auth_n_request: AuthNRequest = self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST) |         auth_n_request: AuthNRequest = self.executor.plan.context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] | ||||||
|         try: |         try: | ||||||
|             response = AssertionProcessor(provider, request, auth_n_request).build_response() |             response = AssertionProcessor(provider, request, auth_n_request).build_response() | ||||||
|         except SAMLException as exc: |         except SAMLException as exc: | ||||||
| @ -106,6 +107,3 @@ class SAMLFlowFinalView(ChallengeStageView): | |||||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: |     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||||
|         # We'll never get here since the challenge redirects to the SP |         # We'll never get here since the challenge redirects to the SP | ||||||
|         return HttpResponseBadRequest() |         return HttpResponseBadRequest() | ||||||
|  |  | ||||||
|     def cleanup(self): |  | ||||||
|         self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST, None) |  | ||||||
|  | |||||||
| @ -19,9 +19,9 @@ from authentik.providers.saml.exceptions import CannotHandleAssertion | |||||||
| from authentik.providers.saml.models import SAMLProvider | from authentik.providers.saml.models import SAMLProvider | ||||||
| from authentik.providers.saml.processors.logout_request_parser import LogoutRequestParser | from authentik.providers.saml.processors.logout_request_parser import LogoutRequestParser | ||||||
| from authentik.providers.saml.views.flows import ( | from authentik.providers.saml.views.flows import ( | ||||||
|  |     PLAN_CONTEXT_SAML_LOGOUT_REQUEST, | ||||||
|     REQUEST_KEY_RELAY_STATE, |     REQUEST_KEY_RELAY_STATE, | ||||||
|     REQUEST_KEY_SAML_REQUEST, |     REQUEST_KEY_SAML_REQUEST, | ||||||
|     SESSION_KEY_LOGOUT_REQUEST, |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -33,6 +33,10 @@ class SAMLSLOView(PolicyAccessView): | |||||||
|  |  | ||||||
|     flow: Flow |     flow: Flow | ||||||
|  |  | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         super().__init__(**kwargs) | ||||||
|  |         self.plan_context = {} | ||||||
|  |  | ||||||
|     def resolve_provider_application(self): |     def resolve_provider_application(self): | ||||||
|         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) |         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||||
|         self.provider: SAMLProvider = get_object_or_404( |         self.provider: SAMLProvider = get_object_or_404( | ||||||
| @ -59,6 +63,7 @@ class SAMLSLOView(PolicyAccessView): | |||||||
|             request, |             request, | ||||||
|             { |             { | ||||||
|                 PLAN_CONTEXT_APPLICATION: self.application, |                 PLAN_CONTEXT_APPLICATION: self.application, | ||||||
|  |                 **self.plan_context, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         plan.append_stage(in_memory_stage(SessionEndStage)) |         plan.append_stage(in_memory_stage(SessionEndStage)) | ||||||
| @ -83,7 +88,7 @@ class SAMLSLOBindingRedirectView(SAMLSLOView): | |||||||
|                 self.request.GET[REQUEST_KEY_SAML_REQUEST], |                 self.request.GET[REQUEST_KEY_SAML_REQUEST], | ||||||
|                 relay_state=self.request.GET.get(REQUEST_KEY_RELAY_STATE, None), |                 relay_state=self.request.GET.get(REQUEST_KEY_RELAY_STATE, None), | ||||||
|             ) |             ) | ||||||
|             self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request |             self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request | ||||||
|         except CannotHandleAssertion as exc: |         except CannotHandleAssertion as exc: | ||||||
|             Event.new( |             Event.new( | ||||||
|                 EventAction.CONFIGURATION_ERROR, |                 EventAction.CONFIGURATION_ERROR, | ||||||
| @ -111,7 +116,7 @@ class SAMLSLOBindingPOSTView(SAMLSLOView): | |||||||
|                 payload[REQUEST_KEY_SAML_REQUEST], |                 payload[REQUEST_KEY_SAML_REQUEST], | ||||||
|                 relay_state=payload.get(REQUEST_KEY_RELAY_STATE, None), |                 relay_state=payload.get(REQUEST_KEY_RELAY_STATE, None), | ||||||
|             ) |             ) | ||||||
|             self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request |             self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request | ||||||
|         except CannotHandleAssertion as exc: |         except CannotHandleAssertion as exc: | ||||||
|             LOGGER.info(str(exc)) |             LOGGER.info(str(exc)) | ||||||
|             return bad_request_message(self.request, str(exc)) |             return bad_request_message(self.request, str(exc)) | ||||||
|  | |||||||
| @ -15,16 +15,16 @@ from authentik.flows.models import in_memory_stage | |||||||
| from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner | from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner | ||||||
| from authentik.flows.views.executor import SESSION_KEY_POST | from authentik.flows.views.executor import SESSION_KEY_POST | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
| from authentik.policies.views import PolicyAccessView | from authentik.policies.views import BufferedPolicyAccessView | ||||||
| from authentik.providers.saml.exceptions import CannotHandleAssertion | from authentik.providers.saml.exceptions import CannotHandleAssertion | ||||||
| from authentik.providers.saml.models import SAMLBindings, SAMLProvider | from authentik.providers.saml.models import SAMLBindings, SAMLProvider | ||||||
| from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser | from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser | ||||||
| from authentik.providers.saml.views.flows import ( | from authentik.providers.saml.views.flows import ( | ||||||
|  |     PLAN_CONTEXT_SAML_AUTH_N_REQUEST, | ||||||
|     REQUEST_KEY_RELAY_STATE, |     REQUEST_KEY_RELAY_STATE, | ||||||
|     REQUEST_KEY_SAML_REQUEST, |     REQUEST_KEY_SAML_REQUEST, | ||||||
|     REQUEST_KEY_SAML_SIG_ALG, |     REQUEST_KEY_SAML_SIG_ALG, | ||||||
|     REQUEST_KEY_SAML_SIGNATURE, |     REQUEST_KEY_SAML_SIGNATURE, | ||||||
|     SESSION_KEY_AUTH_N_REQUEST, |  | ||||||
|     SAMLFlowFinalView, |     SAMLFlowFinalView, | ||||||
| ) | ) | ||||||
| from authentik.stages.consent.stage import ( | from authentik.stages.consent.stage import ( | ||||||
| @ -35,10 +35,14 @@ from authentik.stages.consent.stage import ( | |||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
|  |  | ||||||
| class SAMLSSOView(PolicyAccessView): | class SAMLSSOView(BufferedPolicyAccessView): | ||||||
|     """SAML SSO Base View, which plans a flow and injects our final stage. |     """SAML SSO Base View, which plans a flow and injects our final stage. | ||||||
|     Calls get/post handler.""" |     Calls get/post handler.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, **kwargs): | ||||||
|  |         super().__init__(**kwargs) | ||||||
|  |         self.plan_context = {} | ||||||
|  |  | ||||||
|     def resolve_provider_application(self): |     def resolve_provider_application(self): | ||||||
|         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) |         self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"]) | ||||||
|         self.provider: SAMLProvider = get_object_or_404( |         self.provider: SAMLProvider = get_object_or_404( | ||||||
| @ -68,6 +72,7 @@ class SAMLSSOView(PolicyAccessView): | |||||||
|                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") |                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") | ||||||
|                     % {"application": self.application.name}, |                     % {"application": self.application.name}, | ||||||
|                     PLAN_CONTEXT_CONSENT_PERMISSIONS: [], |                     PLAN_CONTEXT_CONSENT_PERMISSIONS: [], | ||||||
|  |                     **self.plan_context, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
|         except FlowNonApplicableException: |         except FlowNonApplicableException: | ||||||
| @ -83,7 +88,7 @@ class SAMLSSOView(PolicyAccessView): | |||||||
|  |  | ||||||
|     def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: |     def post(self, request: HttpRequest, application_slug: str) -> HttpResponse: | ||||||
|         """GET and POST use the same handler, but we can't |         """GET and POST use the same handler, but we can't | ||||||
|         override .dispatch easily because PolicyAccessView's dispatch""" |         override .dispatch easily because BufferedPolicyAccessView's dispatch""" | ||||||
|         return self.get(request, application_slug) |         return self.get(request, application_slug) | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -103,7 +108,7 @@ class SAMLSSOBindingRedirectView(SAMLSSOView): | |||||||
|                 self.request.GET.get(REQUEST_KEY_SAML_SIGNATURE), |                 self.request.GET.get(REQUEST_KEY_SAML_SIGNATURE), | ||||||
|                 self.request.GET.get(REQUEST_KEY_SAML_SIG_ALG), |                 self.request.GET.get(REQUEST_KEY_SAML_SIG_ALG), | ||||||
|             ) |             ) | ||||||
|             self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request |             self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request | ||||||
|         except CannotHandleAssertion as exc: |         except CannotHandleAssertion as exc: | ||||||
|             Event.new( |             Event.new( | ||||||
|                 EventAction.CONFIGURATION_ERROR, |                 EventAction.CONFIGURATION_ERROR, | ||||||
| @ -137,7 +142,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView): | |||||||
|                 payload[REQUEST_KEY_SAML_REQUEST], |                 payload[REQUEST_KEY_SAML_REQUEST], | ||||||
|                 payload.get(REQUEST_KEY_RELAY_STATE), |                 payload.get(REQUEST_KEY_RELAY_STATE), | ||||||
|             ) |             ) | ||||||
|             self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request |             self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request | ||||||
|         except CannotHandleAssertion as exc: |         except CannotHandleAssertion as exc: | ||||||
|             LOGGER.info(str(exc)) |             LOGGER.info(str(exc)) | ||||||
|             return bad_request_message(self.request, str(exc)) |             return bad_request_message(self.request, str(exc)) | ||||||
| @ -151,4 +156,4 @@ class SAMLSSOBindingInitView(SAMLSSOView): | |||||||
|         """Create SAML Response from scratch""" |         """Create SAML Response from scratch""" | ||||||
|         LOGGER.debug("No SAML Request, using IdP-initiated flow.") |         LOGGER.debug("No SAML Request, using IdP-initiated flow.") | ||||||
|         auth_n_request = AuthNRequestParser(self.provider).idp_initiated() |         auth_n_request = AuthNRequestParser(self.provider).idp_initiated() | ||||||
|         self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request |         self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ from itertools import batched | |||||||
| from django.db import transaction | from django.db import transaction | ||||||
| from pydantic import ValidationError | from pydantic import ValidationError | ||||||
| from pydanticscim.group import GroupMember | from pydanticscim.group import GroupMember | ||||||
|  | from pydanticscim.responses import PatchOp | ||||||
|  |  | ||||||
| from authentik.core.models import Group | from authentik.core.models import Group | ||||||
| from authentik.lib.sync.mapper import PropertyMappingManager | from authentik.lib.sync.mapper import PropertyMappingManager | ||||||
| @ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | |||||||
| from authentik.providers.scim.clients.exceptions import ( | from authentik.providers.scim.clients.exceptions import ( | ||||||
|     SCIMRequestException, |     SCIMRequestException, | ||||||
| ) | ) | ||||||
| from authentik.providers.scim.clients.schema import ( | from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||||
|     SCIM_GROUP_SCHEMA, |  | ||||||
|     PatchOp, |  | ||||||
|     PatchOperation, |  | ||||||
|     PatchRequest, |  | ||||||
| ) |  | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||||
| from authentik.providers.scim.models import ( | from authentik.providers.scim.models import ( | ||||||
|     SCIMMapping, |     SCIMMapping, | ||||||
|  | |||||||
| @ -1,7 +1,5 @@ | |||||||
| """Custom SCIM schemas""" | """Custom SCIM schemas""" | ||||||
|  |  | ||||||
| from enum import Enum |  | ||||||
|  |  | ||||||
| from pydantic import Field | from pydantic import Field | ||||||
| from pydanticscim.group import Group as BaseGroup | from pydanticscim.group import Group as BaseGroup | ||||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||||
| @ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class PatchOp(str, Enum): |  | ||||||
|  |  | ||||||
|     replace = "replace" |  | ||||||
|     remove = "remove" |  | ||||||
|     add = "add" |  | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def _missing_(cls, value): |  | ||||||
|         value = value.lower() |  | ||||||
|         for member in cls: |  | ||||||
|             if member.lower() == value: |  | ||||||
|                 return member |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class PatchRequest(BasePatchRequest): | class PatchRequest(BasePatchRequest): | ||||||
|     """PatchRequest which correctly sets schemas""" |     """PatchRequest which correctly sets schemas""" | ||||||
|  |  | ||||||
| @ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest): | |||||||
| class PatchOperation(BasePatchOperation): | class PatchOperation(BasePatchOperation): | ||||||
|     """PatchOperation with optional path""" |     """PatchOperation with optional path""" | ||||||
|  |  | ||||||
|     op: PatchOp |  | ||||||
|     path: str | None |     path: str | None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -38,7 +38,6 @@ class TestAPIPerms(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
| @ -74,7 +73,6 @@ class TestAPIPerms(APITestCase): | |||||||
|         self.assertJSONEqual( |         self.assertJSONEqual( | ||||||
|             res.content.decode(), |             res.content.decode(), | ||||||
|             { |             { | ||||||
|                 "autocomplete": {}, |  | ||||||
|                 "pagination": { |                 "pagination": { | ||||||
|                     "next": 0, |                     "next": 0, | ||||||
|                     "previous": 0, |                     "previous": 0, | ||||||
|  | |||||||
| @ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | |||||||
|  |  | ||||||
| import django | import django | ||||||
| from channels.routing import ProtocolTypeRouter, URLRouter | from channels.routing import ProtocolTypeRouter, URLRouter | ||||||
|  | from defusedxml import defuse_stdlib | ||||||
| from django.core.asgi import get_asgi_application | from django.core.asgi import get_asgi_application | ||||||
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | ||||||
|  |  | ||||||
| from authentik.root.setup import setup |  | ||||||
|  |  | ||||||
| # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | ||||||
|  |  | ||||||
| setup() | defuse_stdlib() | ||||||
| django.setup() | django.setup() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -27,7 +27,7 @@ from structlog.stdlib import get_logger | |||||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||||
|  |  | ||||||
| from authentik import get_full_version | from authentik import get_full_version | ||||||
| from authentik.lib.sentry import should_ignore_exception | from authentik.lib.sentry import before_send | ||||||
| from authentik.lib.utils.errors import exception_to_string | from authentik.lib.utils.errors import exception_to_string | ||||||
|  |  | ||||||
| # set the default Django settings module for the 'celery' program. | # set the default Django settings module for the 'celery' program. | ||||||
| @ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | |||||||
|  |  | ||||||
|     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) |     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) | ||||||
|     CTX_TASK_ID.set(...) |     CTX_TASK_ID.set(...) | ||||||
|     if not should_ignore_exception(exception): |     if before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||||
|         Event.new( |         Event.new( | ||||||
|             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id |             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id | ||||||
|         ).save() |         ).save() | ||||||
|  | |||||||
| @ -1,49 +1,13 @@ | |||||||
| """authentik database backend""" | """authentik database backend""" | ||||||
|  |  | ||||||
| from django.core.checks import Warning |  | ||||||
| from django.db.backends.base.validation import BaseDatabaseValidation |  | ||||||
| from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseValidation(BaseDatabaseValidation): |  | ||||||
|  |  | ||||||
|     def check(self, **kwargs): |  | ||||||
|         return self._check_encoding() |  | ||||||
|  |  | ||||||
|     def _check_encoding(self): |  | ||||||
|         """Throw a warning when the server_encoding is not UTF-8 or |  | ||||||
|         server_encoding and client_encoding are mismatched""" |  | ||||||
|         messages = [] |  | ||||||
|         with self.connection.cursor() as cursor: |  | ||||||
|             cursor.execute("SHOW server_encoding;") |  | ||||||
|             server_encoding = cursor.fetchone()[0] |  | ||||||
|             cursor.execute("SHOW client_encoding;") |  | ||||||
|             client_encoding = cursor.fetchone()[0] |  | ||||||
|             if server_encoding != client_encoding: |  | ||||||
|                 messages.append( |  | ||||||
|                     Warning( |  | ||||||
|                         "PostgreSQL Server and Client encoding are mismatched: Server: " |  | ||||||
|                         f"{server_encoding}, Client: {client_encoding}", |  | ||||||
|                         id="ak.db.W001", |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             if server_encoding != "UTF8": |  | ||||||
|                 messages.append( |  | ||||||
|                     Warning( |  | ||||||
|                         f"PostgreSQL Server encoding is not UTF8: {server_encoding}", |  | ||||||
|                         id="ak.db.W002", |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|         return messages |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DatabaseWrapper(BaseDatabaseWrapper): | class DatabaseWrapper(BaseDatabaseWrapper): | ||||||
|     """database backend which supports rotating credentials""" |     """database backend which supports rotating credentials""" | ||||||
|  |  | ||||||
|     validation_class = DatabaseValidation |  | ||||||
|  |  | ||||||
|     def get_connection_params(self): |     def get_connection_params(self): | ||||||
|         """Refresh DB credentials before getting connection params""" |         """Refresh DB credentials before getting connection params""" | ||||||
|         conn_params = super().get_connection_params() |         conn_params = super().get_connection_params() | ||||||
|  | |||||||
| @ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [ | |||||||
|     "MIDDLEWARE", |     "MIDDLEWARE", | ||||||
|     "AUTHENTICATION_BACKENDS", |     "AUTHENTICATION_BACKENDS", | ||||||
|     "CELERY", |     "CELERY", | ||||||
|     "SPECTACULAR_SETTINGS", |  | ||||||
|     "REST_FRAMEWORK", |  | ||||||
| ] | ] | ||||||
|  |  | ||||||
| SILENCED_SYSTEM_CHECKS = [ | SILENCED_SYSTEM_CHECKS = [ | ||||||
| @ -470,8 +468,6 @@ def _update_settings(app_path: str): | |||||||
|         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) |         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) | ||||||
|         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) |         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) | ||||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) |         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) |  | ||||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) |  | ||||||
|         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) |         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) | ||||||
|         for _attr in dir(settings_module): |         for _attr in dir(settings_module): | ||||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: |             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||||
|  | |||||||
| @ -1,26 +0,0 @@ | |||||||
| import os |  | ||||||
| import warnings |  | ||||||
|  |  | ||||||
| from cryptography.hazmat.backends.openssl.backend import backend |  | ||||||
| from defusedxml import defuse_stdlib |  | ||||||
|  |  | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def setup(): |  | ||||||
|     warnings.filterwarnings("ignore", "SelectableGroups dict interface") |  | ||||||
|     warnings.filterwarnings( |  | ||||||
|         "ignore", |  | ||||||
|         "defusedxml.lxml is no longer supported and will be removed in a future release.", |  | ||||||
|     ) |  | ||||||
|     warnings.filterwarnings( |  | ||||||
|         "ignore", |  | ||||||
|         "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.", |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     defuse_stdlib() |  | ||||||
|  |  | ||||||
|     if CONFIG.get_bool("compliance.fips.enabled", False): |  | ||||||
|         backend._enable_fips() |  | ||||||
|  |  | ||||||
|     os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") |  | ||||||
| @ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType | |||||||
| from django.test.runner import DiscoverRunner | from django.test.runner import DiscoverRunner | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR |  | ||||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR |  | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.sentry import sentry_init | from authentik.lib.sentry import sentry_init | ||||||
| from authentik.root.signals import post_startup, pre_startup, startup | from authentik.root.signals import post_startup, pre_startup, startup | ||||||
| @ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover | |||||||
|         for key, value in test_config.items(): |         for key, value in test_config.items(): | ||||||
|             CONFIG.set(key, value) |             CONFIG.set(key, value) | ||||||
|  |  | ||||||
|         ASN_CONTEXT_PROCESSOR.load() |  | ||||||
|         GEOIP_CONTEXT_PROCESSOR.load() |  | ||||||
|  |  | ||||||
|         sentry_init() |         sentry_init() | ||||||
|         self.logger.debug("Test environment configured") |         self.logger.debug("Test environment configured") | ||||||
|  |  | ||||||
|  | |||||||
| @ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str): | |||||||
|             return |             return | ||||||
|         # Delete all sync tasks from the cache |         # Delete all sync tasks from the cache | ||||||
|         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() |         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() | ||||||
|  |         task = chain( | ||||||
|         # The order of these operations needs to be preserved as each depends on the previous one(s) |             # User and group sync can happen at once, they have no dependencies on each other | ||||||
|         # 1. User and group sync can happen simultaneously |             group( | ||||||
|         # 2. Membership sync needs to run afterwards |                 ldap_sync_paginator(source, UserLDAPSynchronizer) | ||||||
|         # 3. Finally, user and group deletions can happen simultaneously |                 + ldap_sync_paginator(source, GroupLDAPSynchronizer), | ||||||
|         user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( |             ), | ||||||
|             source, GroupLDAPSynchronizer |             # Membership sync needs to run afterwards | ||||||
|  |             group( | ||||||
|  |                 ldap_sync_paginator(source, MembershipLDAPSynchronizer), | ||||||
|  |             ), | ||||||
|  |             # Finally, deletions. What we'd really like to do here is something like | ||||||
|  |             # ``` | ||||||
|  |             # user_identifiers = <ldap query> | ||||||
|  |             # User.objects.exclude( | ||||||
|  |             #     usersourceconnection__identifier__in=user_uniqueness_identifiers, | ||||||
|  |             # ).delete() | ||||||
|  |             # ``` | ||||||
|  |             # This runs into performance issues in large installations. So instead we spread the | ||||||
|  |             # work out into three steps: | ||||||
|  |             # 1. Get every object from the LDAP source. | ||||||
|  |             # 2. Mark every object as "safe" in the database. This is quick, but any error could | ||||||
|  |             #    mean deleting users which should not be deleted, so we do it immediately, in | ||||||
|  |             #    large chunks, and only queue the deletion step afterwards. | ||||||
|  |             # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in | ||||||
|  |             #    small chunks. | ||||||
|  |             group( | ||||||
|  |                 ldap_sync_paginator(source, UserLDAPForwardDeletion) | ||||||
|  |                 + ldap_sync_paginator(source, GroupLDAPForwardDeletion), | ||||||
|  |             ), | ||||||
|         ) |         ) | ||||||
|         membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) |         task() | ||||||
|         user_group_deletion = ldap_sync_paginator( |  | ||||||
|             source, UserLDAPForwardDeletion |  | ||||||
|         ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion) |  | ||||||
|  |  | ||||||
|         # Celery is buggy with empty groups, so we are careful only to add non-empty groups. |  | ||||||
|         # See https://github.com/celery/celery/issues/9772 |  | ||||||
|         task_groups = [] |  | ||||||
|         if user_group_sync: |  | ||||||
|             task_groups.append(group(user_group_sync)) |  | ||||||
|         if membership_sync: |  | ||||||
|             task_groups.append(group(membership_sync)) |  | ||||||
|         if user_group_deletion: |  | ||||||
|             task_groups.append(group(user_group_deletion)) |  | ||||||
|  |  | ||||||
|         all_tasks = chain(task_groups) |  | ||||||
|         all_tasks() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | ||||||
|  | |||||||
| @ -1,277 +0,0 @@ | |||||||
| """Test SCIM Group""" |  | ||||||
|  |  | ||||||
| from json import dumps |  | ||||||
| from uuid import uuid4 |  | ||||||
|  |  | ||||||
| from django.urls import reverse |  | ||||||
| from rest_framework.test import APITestCase |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group |  | ||||||
| from authentik.core.tests.utils import create_test_user |  | ||||||
| from authentik.events.models import Event, EventAction |  | ||||||
| from authentik.lib.generators import generate_id |  | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema |  | ||||||
| from authentik.sources.scim.models import ( |  | ||||||
|     SCIMSource, |  | ||||||
|     SCIMSourceGroup, |  | ||||||
| ) |  | ||||||
| from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSCIMGroups(APITestCase): |  | ||||||
|     """Test SCIM Group view""" |  | ||||||
|  |  | ||||||
|     def setUp(self) -> None: |  | ||||||
|         self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id()) |  | ||||||
|  |  | ||||||
|     def test_group_list(self): |  | ||||||
|         """Test full group list""" |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_group_list_single(self): |  | ||||||
|         """Test full group list (single group)""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         user = create_test_user() |  | ||||||
|         group.users.add(user) |  | ||||||
|         SCIMSourceGroup.objects.create( |  | ||||||
|             source=self.source, |  | ||||||
|             group=group, |  | ||||||
|             id=str(uuid4()), |  | ||||||
|         ) |  | ||||||
|         response = self.client.get( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "group_id": str(group.pk), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         SCIMGroupSchema.model_validate_json(response.content, strict=True) |  | ||||||
|  |  | ||||||
|     def test_group_create(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_members(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "displayName": generate_id(), |  | ||||||
|                     "externalId": ext_id, |  | ||||||
|                     "members": [{"value": str(user.uuid)}], |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_members_empty(self): |  | ||||||
|         """Test group create""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 201) |  | ||||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) |  | ||||||
|         self.assertTrue( |  | ||||||
|             Event.objects.filter( |  | ||||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username |  | ||||||
|             ).exists() |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_create_duplicate(self): |  | ||||||
|         """Test group create (duplicate)""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)} |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 409) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             response.content, |  | ||||||
|             { |  | ||||||
|                 "detail": "Group with ID exists already.", |  | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], |  | ||||||
|                 "scimType": "uniqueness", |  | ||||||
|                 "status": 409, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_update(self): |  | ||||||
|         """Test group update""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)} |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|  |  | ||||||
|     def test_group_update_non_existent(self): |  | ||||||
|         """Test group update""" |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "group_id": str(uuid4()), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=404) |  | ||||||
|         self.assertJSONEqual( |  | ||||||
|             response.content, |  | ||||||
|             { |  | ||||||
|                 "detail": "Group not found.", |  | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], |  | ||||||
|                 "status": 404, |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     def test_group_patch_add(self): |  | ||||||
|         """Test group patch""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.patch( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "Operations": [ |  | ||||||
|                         { |  | ||||||
|                             "op": "Add", |  | ||||||
|                             "path": "members", |  | ||||||
|                             "value": {"value": str(user.uuid)}, |  | ||||||
|                         } |  | ||||||
|                     ] |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         self.assertTrue(group.users.filter(pk=user.pk).exists()) |  | ||||||
|  |  | ||||||
|     def test_group_patch_remove(self): |  | ||||||
|         """Test group patch""" |  | ||||||
|         user = create_test_user() |  | ||||||
|  |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         group.users.add(user) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.patch( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "Operations": [ |  | ||||||
|                         { |  | ||||||
|                             "op": "remove", |  | ||||||
|                             "path": "members", |  | ||||||
|                             "value": {"value": str(user.uuid)}, |  | ||||||
|                         } |  | ||||||
|                     ] |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=200) |  | ||||||
|         self.assertFalse(group.users.filter(pk=user.pk).exists()) |  | ||||||
|  |  | ||||||
|     def test_group_delete(self): |  | ||||||
|         """Test group delete""" |  | ||||||
|         group = Group.objects.create(name=generate_id()) |  | ||||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) |  | ||||||
|         response = self.client.delete( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-groups", |  | ||||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, second=204) |  | ||||||
| @ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase): | |||||||
|             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], |             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], | ||||||
|             "0123456789", |             "0123456789", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_user_update(self): |  | ||||||
|         """Test user update""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) |  | ||||||
|         ext_id = generate_id() |  | ||||||
|         response = self.client.put( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-users", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "user_id": str(user.uuid), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             data=dumps( |  | ||||||
|                 { |  | ||||||
|                     "id": str(existing.pk), |  | ||||||
|                     "userName": generate_id(), |  | ||||||
|                     "externalId": ext_id, |  | ||||||
|                     "emails": [ |  | ||||||
|                         { |  | ||||||
|                             "primary": True, |  | ||||||
|                             "value": user.email, |  | ||||||
|                         } |  | ||||||
|                     ], |  | ||||||
|                 } |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|  |  | ||||||
|     def test_user_delete(self): |  | ||||||
|         """Test user delete""" |  | ||||||
|         user = create_test_user() |  | ||||||
|         SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) |  | ||||||
|         response = self.client.delete( |  | ||||||
|             reverse( |  | ||||||
|                 "authentik_sources_scim:v2-users", |  | ||||||
|                 kwargs={ |  | ||||||
|                     "source_slug": self.source.slug, |  | ||||||
|                     "user_id": str(user.uuid), |  | ||||||
|                 }, |  | ||||||
|             ), |  | ||||||
|             content_type=SCIM_CONTENT_TYPE, |  | ||||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 204) |  | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_ | |||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.views import APIView | from rest_framework.views import APIView | ||||||
|  |  | ||||||
| from authentik.core.middleware import CTX_AUTH_VIA |  | ||||||
| from authentik.core.models import Token, TokenIntents, User | from authentik.core.models import Token, TokenIntents, User | ||||||
| from authentik.sources.scim.models import SCIMSource | from authentik.sources.scim.models import SCIMSource | ||||||
|  |  | ||||||
| @ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication): | |||||||
|         _username, _, password = b64decode(key.encode()).decode().partition(":") |         _username, _, password = b64decode(key.encode()).decode().partition(":") | ||||||
|         token = self.check_token(password, source_slug) |         token = self.check_token(password, source_slug) | ||||||
|         if token: |         if token: | ||||||
|             CTX_AUTH_VIA.set("scim_basic") |  | ||||||
|             return (token.user, token) |             return (token.user, token) | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
| @ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication): | |||||||
|         token = self.check_token(key, source_slug) |         token = self.check_token(key, source_slug) | ||||||
|         if not token: |         if not token: | ||||||
|             return None |             return None | ||||||
|         CTX_AUTH_VIA.set("scim_token") |  | ||||||
|         return (token.user, token) |         return (token.user, token) | ||||||
|  | |||||||
| @ -1,11 +1,13 @@ | |||||||
| """SCIM Utils""" | """SCIM Utils""" | ||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  | from urllib.parse import urlparse | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.core.paginator import Page, Paginator | from django.core.paginator import Page, Paginator | ||||||
| from django.db.models import Q, QuerySet | from django.db.models import Q, QuerySet | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
|  | from django.urls import resolve | ||||||
| from rest_framework.parsers import JSONParser | from rest_framework.parsers import JSONParser | ||||||
| from rest_framework.permissions import IsAuthenticated | from rest_framework.permissions import IsAuthenticated | ||||||
| from rest_framework.renderers import JSONRenderer | from rest_framework.renderers import JSONRenderer | ||||||
| @ -44,7 +46,7 @@ class SCIMView(APIView): | |||||||
|     logger: BoundLogger |     logger: BoundLogger | ||||||
|  |  | ||||||
|     permission_classes = [IsAuthenticated] |     permission_classes = [IsAuthenticated] | ||||||
|     parser_classes = [SCIMParser, JSONParser] |     parser_classes = [SCIMParser] | ||||||
|     renderer_classes = [SCIMRenderer] |     renderer_classes = [SCIMRenderer] | ||||||
|  |  | ||||||
|     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: |     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: | ||||||
| @ -54,6 +56,28 @@ class SCIMView(APIView): | |||||||
|     def get_authenticators(self): |     def get_authenticators(self): | ||||||
|         return [SCIMTokenAuth(self)] |         return [SCIMTokenAuth(self)] | ||||||
|  |  | ||||||
|  |     def patch_resolve_value(self, raw_value: dict) -> User | Group | None: | ||||||
|  |         """Attempt to resolve a raw `value` attribute of a patch operation into | ||||||
|  |         a database model""" | ||||||
|  |         model = User | ||||||
|  |         query = {} | ||||||
|  |         if "$ref" in raw_value: | ||||||
|  |             url = urlparse(raw_value["$ref"]) | ||||||
|  |             if match := resolve(url.path): | ||||||
|  |                 if match.url_name == "v2-users": | ||||||
|  |                     model = User | ||||||
|  |                     query = {"pk": int(match.kwargs["user_id"])} | ||||||
|  |         elif "type" in raw_value: | ||||||
|  |             match raw_value["type"]: | ||||||
|  |                 case "User": | ||||||
|  |                     model = User | ||||||
|  |                     query = {"pk": int(raw_value["value"])} | ||||||
|  |                 case "Group": | ||||||
|  |                     model = Group | ||||||
|  |         else: | ||||||
|  |             return None | ||||||
|  |         return model.objects.filter(**query).first() | ||||||
|  |  | ||||||
|     def filter_parse(self, request: Request): |     def filter_parse(self, request: Request): | ||||||
|         """Parse the path of a Patch Operation""" |         """Parse the path of a Patch Operation""" | ||||||
|         path = request.query_params.get("filter") |         path = request.query_params.get("filter") | ||||||
|  | |||||||
| @ -1,58 +0,0 @@ | |||||||
| from enum import Enum |  | ||||||
|  |  | ||||||
| from pydanticscim.responses import SCIMError as BaseSCIMError |  | ||||||
| from rest_framework.exceptions import ValidationError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMErrorTypes(Enum): |  | ||||||
|     invalid_filter = "invalidFilter" |  | ||||||
|     too_many = "tooMany" |  | ||||||
|     uniqueness = "uniqueness" |  | ||||||
|     mutability = "mutability" |  | ||||||
|     invalid_syntax = "invalidSyntax" |  | ||||||
|     invalid_path = "invalidPath" |  | ||||||
|     no_target = "noTarget" |  | ||||||
|     invalid_value = "invalidValue" |  | ||||||
|     invalid_vers = "invalidVers" |  | ||||||
|     sensitive = "sensitive" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMError(BaseSCIMError): |  | ||||||
|     scimType: SCIMErrorTypes | None = None |  | ||||||
|     detail: str | None = None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMValidationError(ValidationError): |  | ||||||
|     status_code = 400 |  | ||||||
|     default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400) |  | ||||||
|  |  | ||||||
|     def __init__(self, detail: SCIMError | None): |  | ||||||
|         if detail is None: |  | ||||||
|             detail = self.default_detail |  | ||||||
|         detail.status = self.status_code |  | ||||||
|         self.detail = detail.model_dump(mode="json", exclude_none=True) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMConflictError(SCIMValidationError): |  | ||||||
|     status_code = 409 |  | ||||||
|  |  | ||||||
|     def __init__(self, detail: str): |  | ||||||
|         super().__init__( |  | ||||||
|             SCIMError( |  | ||||||
|                 detail=detail, |  | ||||||
|                 scimType=SCIMErrorTypes.uniqueness, |  | ||||||
|                 status=self.status_code, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMNotFoundError(SCIMValidationError): |  | ||||||
|     status_code = 404 |  | ||||||
|  |  | ||||||
|     def __init__(self, detail: str): |  | ||||||
|         super().__init__( |  | ||||||
|             SCIMError( |  | ||||||
|                 detail=detail, |  | ||||||
|                 status=self.status_code, |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
| @ -4,25 +4,19 @@ from uuid import uuid4 | |||||||
|  |  | ||||||
| from django.db.models import Q | from django.db.models import Q | ||||||
| from django.db.transaction import atomic | from django.db.transaction import atomic | ||||||
| from django.http import QueryDict | from django.http import Http404, QueryDict | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from pydantic import ValidationError as PydanticValidationError | from pydantic import ValidationError as PydanticValidationError | ||||||
| from pydanticscim.group import GroupMember | from pydanticscim.group import GroupMember | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| from scim2_filter_parser.attr_paths import AttrPath |  | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation | from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupModel | from authentik.providers.scim.clients.schema import Group as SCIMGroupModel | ||||||
| from authentik.sources.scim.models import SCIMSourceGroup | from authentik.sources.scim.models import SCIMSourceGroup | ||||||
| from authentik.sources.scim.views.v2.base import SCIMObjectView | from authentik.sources.scim.views.v2.base import SCIMObjectView | ||||||
| from authentik.sources.scim.views.v2.exceptions import ( |  | ||||||
|     SCIMConflictError, |  | ||||||
|     SCIMNotFoundError, |  | ||||||
|     SCIMValidationError, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class GroupsView(SCIMObjectView): | class GroupsView(SCIMObjectView): | ||||||
| @ -33,7 +27,7 @@ class GroupsView(SCIMObjectView): | |||||||
|     def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: |     def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: | ||||||
|         """Convert Group to SCIM data""" |         """Convert Group to SCIM data""" | ||||||
|         payload = SCIMGroupModel( |         payload = SCIMGroupModel( | ||||||
|             schemas=[SCIM_GROUP_SCHEMA], |             schemas=[SCIM_USER_SCHEMA], | ||||||
|             id=str(scim_group.group.pk), |             id=str(scim_group.group.pk), | ||||||
|             externalId=scim_group.id, |             externalId=scim_group.id, | ||||||
|             displayName=scim_group.group.name, |             displayName=scim_group.group.name, | ||||||
| @ -64,7 +58,7 @@ class GroupsView(SCIMObjectView): | |||||||
|         if group_id: |         if group_id: | ||||||
|             connection = base_query.filter(source=self.source, group__group_uuid=group_id).first() |             connection = base_query.filter(source=self.source, group__group_uuid=group_id).first() | ||||||
|             if not connection: |             if not connection: | ||||||
|                 raise SCIMNotFoundError("Group not found.") |                 raise Http404 | ||||||
|             return Response(self.group_to_scim(connection)) |             return Response(self.group_to_scim(connection)) | ||||||
|         connections = ( |         connections = ( | ||||||
|             base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request)) |             base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request)) | ||||||
| @ -125,7 +119,7 @@ class GroupsView(SCIMObjectView): | |||||||
|         ).first() |         ).first() | ||||||
|         if connection: |         if connection: | ||||||
|             self.logger.debug("Found existing group") |             self.logger.debug("Found existing group") | ||||||
|             raise SCIMConflictError("Group with ID exists already.") |             return Response(status=409) | ||||||
|         connection = self.update_group(None, request.data) |         connection = self.update_group(None, request.data) | ||||||
|         return Response(self.group_to_scim(connection), status=201) |         return Response(self.group_to_scim(connection), status=201) | ||||||
|  |  | ||||||
| @ -135,44 +129,10 @@ class GroupsView(SCIMObjectView): | |||||||
|             source=self.source, group__group_uuid=group_id |             source=self.source, group__group_uuid=group_id | ||||||
|         ).first() |         ).first() | ||||||
|         if not connection: |         if not connection: | ||||||
|             raise SCIMNotFoundError("Group not found.") |             raise Http404 | ||||||
|         connection = self.update_group(connection, request.data) |         connection = self.update_group(connection, request.data) | ||||||
|         return Response(self.group_to_scim(connection), status=200) |         return Response(self.group_to_scim(connection), status=200) | ||||||
|  |  | ||||||
|     @atomic |  | ||||||
|     def patch(self, request: Request, group_id: str, **kwargs) -> Response: |  | ||||||
|         """Patch group handler""" |  | ||||||
|         connection = SCIMSourceGroup.objects.filter( |  | ||||||
|             source=self.source, group__group_uuid=group_id |  | ||||||
|         ).first() |  | ||||||
|         if not connection: |  | ||||||
|             raise SCIMNotFoundError("Group not found.") |  | ||||||
|  |  | ||||||
|         for _op in request.data.get("Operations", []): |  | ||||||
|             operation = PatchOperation.model_validate(_op) |  | ||||||
|             if operation.op.lower() not in ["add", "remove", "replace"]: |  | ||||||
|                 raise SCIMValidationError() |  | ||||||
|             attr_path = AttrPath(f'{operation.path} eq ""', {}) |  | ||||||
|             if attr_path.first_path == ("members", None, None): |  | ||||||
|                 # FIXME: this can probably be de-duplicated |  | ||||||
|                 if operation.op == PatchOp.add: |  | ||||||
|                     if not isinstance(operation.value, list): |  | ||||||
|                         operation.value = [operation.value] |  | ||||||
|                     query = Q() |  | ||||||
|                     for member in operation.value: |  | ||||||
|                         query |= Q(uuid=member["value"]) |  | ||||||
|                     if query: |  | ||||||
|                         connection.group.users.add(*User.objects.filter(query)) |  | ||||||
|                 elif operation.op == PatchOp.remove: |  | ||||||
|                     if not isinstance(operation.value, list): |  | ||||||
|                         operation.value = [operation.value] |  | ||||||
|                     query = Q() |  | ||||||
|                     for member in operation.value: |  | ||||||
|                         query |= Q(uuid=member["value"]) |  | ||||||
|                     if query: |  | ||||||
|                         connection.group.users.remove(*User.objects.filter(query)) |  | ||||||
|         return Response(self.group_to_scim(connection), status=200) |  | ||||||
|  |  | ||||||
|     @atomic |     @atomic | ||||||
|     def delete(self, request: Request, group_id: str, **kwargs) -> Response: |     def delete(self, request: Request, group_id: str, **kwargs) -> Response: | ||||||
|         """Delete group handler""" |         """Delete group handler""" | ||||||
| @ -180,7 +140,7 @@ class GroupsView(SCIMObjectView): | |||||||
|             source=self.source, group__group_uuid=group_id |             source=self.source, group__group_uuid=group_id | ||||||
|         ).first() |         ).first() | ||||||
|         if not connection: |         if not connection: | ||||||
|             raise SCIMNotFoundError("Group not found.") |             raise Http404 | ||||||
|         connection.group.delete() |         connection.group.delete() | ||||||
|         connection.delete() |         connection.delete() | ||||||
|         return Response(status=204) |         return Response(status=204) | ||||||
|  | |||||||
| @ -1,11 +1,11 @@ | |||||||
| """SCIM Meta views""" | """SCIM Meta views""" | ||||||
|  |  | ||||||
|  | from django.http import Http404 | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
|  |  | ||||||
| from authentik.sources.scim.views.v2.base import SCIMView | from authentik.sources.scim.views.v2.base import SCIMView | ||||||
| from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ResourceTypesView(SCIMView): | class ResourceTypesView(SCIMView): | ||||||
| @ -138,7 +138,7 @@ class ResourceTypesView(SCIMView): | |||||||
|             resource = [x for x in resource_types if x.get("id") == resource_type] |             resource = [x for x in resource_types if x.get("id") == resource_type] | ||||||
|             if resource: |             if resource: | ||||||
|                 return Response(resource[0]) |                 return Response(resource[0]) | ||||||
|             raise SCIMNotFoundError("Resource not found.") |             raise Http404 | ||||||
|         return Response( |         return Response( | ||||||
|             { |             { | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], |                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], | ||||||
|  | |||||||
| @ -3,12 +3,12 @@ | |||||||
| from json import loads | from json import loads | ||||||
|  |  | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
|  | from django.http import Http404 | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
|  |  | ||||||
| from authentik.sources.scim.views.v2.base import SCIMView | from authentik.sources.scim.views.v2.base import SCIMView | ||||||
| from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError |  | ||||||
|  |  | ||||||
| with open( | with open( | ||||||
|     settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json", |     settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json", | ||||||
| @ -44,7 +44,7 @@ class SchemaView(SCIMView): | |||||||
|             schema = [x for x in schemas if x.get("id") == schema_uri] |             schema = [x for x in schemas if x.get("id") == schema_uri] | ||||||
|             if schema: |             if schema: | ||||||
|                 return Response(schema[0]) |                 return Response(schema[0]) | ||||||
|             raise SCIMNotFoundError("Schema not found.") |             raise Http404 | ||||||
|         return Response( |         return Response( | ||||||
|             { |             { | ||||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], |                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], | ||||||
|  | |||||||
| @ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView): | |||||||
|             { |             { | ||||||
|                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], |                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], | ||||||
|                 "authenticationSchemes": auth_schemas, |                 "authenticationSchemes": auth_schemas, | ||||||
|                 # We only support patch for groups currently, so don't broadly advertise it. |  | ||||||
|                 # Implementations that require Group patch will use it regardless of this flag. |  | ||||||
|                 "patch": {"supported": False}, |                 "patch": {"supported": False}, | ||||||
|                 "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0}, |                 "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0}, | ||||||
|                 "filter": { |                 "filter": { | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ from uuid import uuid4 | |||||||
|  |  | ||||||
| from django.db.models import Q | from django.db.models import Q | ||||||
| from django.db.transaction import atomic | from django.db.transaction import atomic | ||||||
| from django.http import QueryDict | from django.http import Http404, QueryDict | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from pydanticscim.user import Email, EmailKind, Name | from pydanticscim.user import Email, EmailKind, Name | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| @ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | |||||||
| from authentik.providers.scim.clients.schema import User as SCIMUserModel | from authentik.providers.scim.clients.schema import User as SCIMUserModel | ||||||
| from authentik.sources.scim.models import SCIMSourceUser | from authentik.sources.scim.models import SCIMSourceUser | ||||||
| from authentik.sources.scim.views.v2.base import SCIMObjectView | from authentik.sources.scim.views.v2.base import SCIMObjectView | ||||||
| from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class UsersView(SCIMObjectView): | class UsersView(SCIMObjectView): | ||||||
| @ -70,7 +69,7 @@ class UsersView(SCIMObjectView): | |||||||
|                 .first() |                 .first() | ||||||
|             ) |             ) | ||||||
|             if not connection: |             if not connection: | ||||||
|                 raise SCIMNotFoundError("User not found.") |                 raise Http404 | ||||||
|             return Response(self.user_to_scim(connection)) |             return Response(self.user_to_scim(connection)) | ||||||
|         connections = ( |         connections = ( | ||||||
|             SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk") |             SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk") | ||||||
| @ -123,7 +122,7 @@ class UsersView(SCIMObjectView): | |||||||
|         ).first() |         ).first() | ||||||
|         if connection: |         if connection: | ||||||
|             self.logger.debug("Found existing user") |             self.logger.debug("Found existing user") | ||||||
|             raise SCIMConflictError("Group with ID exists already.") |             return Response(status=409) | ||||||
|         connection = self.update_user(None, request.data) |         connection = self.update_user(None, request.data) | ||||||
|         return Response(self.user_to_scim(connection), status=201) |         return Response(self.user_to_scim(connection), status=201) | ||||||
|  |  | ||||||
| @ -131,7 +130,7 @@ class UsersView(SCIMObjectView): | |||||||
|         """Update user handler""" |         """Update user handler""" | ||||||
|         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() |         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() | ||||||
|         if not connection: |         if not connection: | ||||||
|             raise SCIMNotFoundError("User not found.") |             raise Http404 | ||||||
|         self.update_user(connection, request.data) |         self.update_user(connection, request.data) | ||||||
|         return Response(self.user_to_scim(connection), status=200) |         return Response(self.user_to_scim(connection), status=200) | ||||||
|  |  | ||||||
| @ -140,7 +139,7 @@ class UsersView(SCIMObjectView): | |||||||
|         """Delete user handler""" |         """Delete user handler""" | ||||||
|         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() |         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() | ||||||
|         if not connection: |         if not connection: | ||||||
|             raise SCIMNotFoundError("User not found.") |             raise Http404 | ||||||
|         connection.user.delete() |         connection.user.delete() | ||||||
|         connection.delete() |         connection.delete() | ||||||
|         return Response(status=204) |         return Response(status=204) | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| """Validation stage challenge checking""" | """Validation stage challenge checking""" | ||||||
|  |  | ||||||
| from json import loads | from json import loads | ||||||
| from typing import TYPE_CHECKING |  | ||||||
| from urllib.parse import urlencode | from urllib.parse import urlencode | ||||||
|  |  | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice | |||||||
| from authentik.stages.authenticator_sms.models import SMSDevice | from authentik.stages.authenticator_sms.models import SMSDevice | ||||||
| from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses | from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses | ||||||
| from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice | from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice | ||||||
| from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE | from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE | ||||||
| from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceChallenge(PassiveSerializer): | class DeviceChallenge(PassiveSerializer): | ||||||
| @ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer): | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_challenge_for_device( | def get_challenge_for_device( | ||||||
|     stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device |     request: HttpRequest, stage: AuthenticatorValidateStage, device: Device | ||||||
| ) -> dict: | ) -> dict: | ||||||
|     """Generate challenge for a single device""" |     """Generate challenge for a single device""" | ||||||
|     if isinstance(device, WebAuthnDevice): |     if isinstance(device, WebAuthnDevice): | ||||||
|         return get_webauthn_challenge(stage_view, stage, device) |         return get_webauthn_challenge(request, stage, device) | ||||||
|     if isinstance(device, EmailDevice): |     if isinstance(device, EmailDevice): | ||||||
|         return {"email": mask_email(device.email)} |         return {"email": mask_email(device.email)} | ||||||
|     # Code-based challenges have no hints |     # Code-based challenges have no hints | ||||||
| @ -67,30 +64,26 @@ def get_challenge_for_device( | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_webauthn_challenge_without_user( | def get_webauthn_challenge_without_user( | ||||||
|     stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage |     request: HttpRequest, stage: AuthenticatorValidateStage | ||||||
| ) -> dict: | ) -> dict: | ||||||
|     """Same as `get_webauthn_challenge`, but allows any client device. We can then later check |     """Same as `get_webauthn_challenge`, but allows any client device. We can then later check | ||||||
|     who the device belongs to.""" |     who the device belongs to.""" | ||||||
|     stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) |     request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||||
|     authentication_options = generate_authentication_options( |     authentication_options = generate_authentication_options( | ||||||
|         rp_id=get_rp_id(stage_view.request), |         rp_id=get_rp_id(request), | ||||||
|         allow_credentials=[], |         allow_credentials=[], | ||||||
|         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), |         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), | ||||||
|     ) |     ) | ||||||
|     stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( |     request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge | ||||||
|         authentication_options.challenge |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     return loads(options_to_json(authentication_options)) |     return loads(options_to_json(authentication_options)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_webauthn_challenge( | def get_webauthn_challenge( | ||||||
|     stage_view: "AuthenticatorValidateStageView", |     request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None | ||||||
|     stage: AuthenticatorValidateStage, |  | ||||||
|     device: WebAuthnDevice | None = None, |  | ||||||
| ) -> dict: | ) -> dict: | ||||||
|     """Send the client a challenge that we'll check later""" |     """Send the client a challenge that we'll check later""" | ||||||
|     stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) |     request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||||
|  |  | ||||||
|     allowed_credentials = [] |     allowed_credentials = [] | ||||||
|  |  | ||||||
| @ -101,14 +94,12 @@ def get_webauthn_challenge( | |||||||
|             allowed_credentials.append(user_device.descriptor) |             allowed_credentials.append(user_device.descriptor) | ||||||
|  |  | ||||||
|     authentication_options = generate_authentication_options( |     authentication_options = generate_authentication_options( | ||||||
|         rp_id=get_rp_id(stage_view.request), |         rp_id=get_rp_id(request), | ||||||
|         allow_credentials=allowed_credentials, |         allow_credentials=allowed_credentials, | ||||||
|         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), |         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( |     request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge | ||||||
|         authentication_options.challenge |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     return loads(options_to_json(authentication_options)) |     return loads(options_to_json(authentication_options)) | ||||||
|  |  | ||||||
| @ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev | |||||||
| def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device: | def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device: | ||||||
|     """Validate WebAuthn Challenge""" |     """Validate WebAuthn Challenge""" | ||||||
|     request = stage_view.request |     request = stage_view.request | ||||||
|     challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE) |     challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE) | ||||||
|     stage: AuthenticatorValidateStage = stage_view.executor.current_stage |     stage: AuthenticatorValidateStage = stage_view.executor.current_stage | ||||||
|     try: |     try: | ||||||
|         credential = parse_authentication_credential_json(data) |         credential = parse_authentication_credential_json(data) | ||||||
|  | |||||||
| @ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): | |||||||
|                 data={ |                 data={ | ||||||
|                     "device_class": device_class, |                     "device_class": device_class, | ||||||
|                     "device_uid": device.pk, |                     "device_uid": device.pk, | ||||||
|                     "challenge": get_challenge_for_device(self, stage, device), |                     "challenge": get_challenge_for_device(self.request, stage, device), | ||||||
|                     "last_used": device.last_used, |                     "last_used": device.last_used, | ||||||
|                 } |                 } | ||||||
|             ) |             ) | ||||||
| @ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): | |||||||
|                 "device_class": DeviceClasses.WEBAUTHN, |                 "device_class": DeviceClasses.WEBAUTHN, | ||||||
|                 "device_uid": -1, |                 "device_uid": -1, | ||||||
|                 "challenge": get_webauthn_challenge_without_user( |                 "challenge": get_webauthn_challenge_without_user( | ||||||
|                     self, |                     self.request, | ||||||
|                     self.executor.current_stage, |                     self.executor.current_stage, | ||||||
|                 ), |                 ), | ||||||
|                 "last_used": None, |                 "last_used": None, | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import ( | |||||||
|     WebAuthnDevice, |     WebAuthnDevice, | ||||||
|     WebAuthnDeviceType, |     WebAuthnDeviceType, | ||||||
| ) | ) | ||||||
| from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE | from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE | ||||||
| from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import | from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import | ||||||
| from authentik.stages.identification.models import IdentificationStage, UserFields | from authentik.stages.identification.models import IdentificationStage, UserFields | ||||||
| from authentik.stages.user_login.models import UserLoginStage | from authentik.stages.user_login.models import UserLoginStage | ||||||
| @ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|             device_classes=[DeviceClasses.WEBAUTHN], |             device_classes=[DeviceClasses.WEBAUTHN], | ||||||
|             webauthn_user_verification=UserVerification.PREFERRED, |             webauthn_user_verification=UserVerification.PREFERRED, | ||||||
|         ) |         ) | ||||||
|         plan = FlowPlan("") |         challenge = get_challenge_for_device(request, stage, webauthn_device) | ||||||
|         stage_view = AuthenticatorValidateStageView( |  | ||||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request |  | ||||||
|         ) |  | ||||||
|         challenge = get_challenge_for_device(stage_view, stage, webauthn_device) |  | ||||||
|         del challenge["challenge"] |         del challenge["challenge"] | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             challenge, |             challenge, | ||||||
| @ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|  |  | ||||||
|         with self.assertRaises(ValidationError): |         with self.assertRaises(ValidationError): | ||||||
|             validate_challenge_webauthn( |             validate_challenge_webauthn( | ||||||
|                 {}, |                 {}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user | ||||||
|                 StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request), |  | ||||||
|                 self.user, |  | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_device_challenge_webauthn_restricted(self): |     def test_device_challenge_webauthn_restricted(self): | ||||||
| @ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|             sign_count=0, |             sign_count=0, | ||||||
|             rp_id=generate_id(), |             rp_id=generate_id(), | ||||||
|         ) |         ) | ||||||
|         plan = FlowPlan("") |         challenge = get_challenge_for_device(request, stage, webauthn_device) | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( |         webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" |  | ||||||
|         ) |  | ||||||
|         stage_view = AuthenticatorValidateStageView( |  | ||||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request |  | ||||||
|         ) |  | ||||||
|         challenge = get_challenge_for_device(stage_view, stage, webauthn_device) |  | ||||||
|         self.assertEqual( |         self.assertEqual( | ||||||
|             challenge["allowCredentials"], |             challenge, | ||||||
|             [ |             { | ||||||
|                 { |                 "allowCredentials": [ | ||||||
|                     "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU", |                     { | ||||||
|                     "type": "public-key", |                         "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU", | ||||||
|                 } |                         "type": "public-key", | ||||||
|             ], |                     } | ||||||
|         ) |                 ], | ||||||
|         self.assertIsNotNone(challenge["challenge"]) |                 "challenge": bytes_to_base64url(webauthn_challenge), | ||||||
|         self.assertEqual( |                 "rpId": "testserver", | ||||||
|             challenge["rpId"], |                 "timeout": 60000, | ||||||
|             "testserver", |                 "userVerification": "preferred", | ||||||
|         ) |             }, | ||||||
|         self.assertEqual( |  | ||||||
|             challenge["timeout"], |  | ||||||
|             60000, |  | ||||||
|         ) |  | ||||||
|         self.assertEqual( |  | ||||||
|             challenge["userVerification"], |  | ||||||
|             "preferred", |  | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def test_get_challenge_userless(self): |     def test_get_challenge_userless(self): | ||||||
| @ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|             sign_count=0, |             sign_count=0, | ||||||
|             rp_id=generate_id(), |             rp_id=generate_id(), | ||||||
|         ) |         ) | ||||||
|         plan = FlowPlan("") |         challenge = get_webauthn_challenge_without_user(request, stage) | ||||||
|         stage_view = AuthenticatorValidateStageView( |         webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request |         self.assertEqual( | ||||||
|  |             challenge, | ||||||
|  |             { | ||||||
|  |                 "allowCredentials": [], | ||||||
|  |                 "challenge": bytes_to_base64url(webauthn_challenge), | ||||||
|  |                 "rpId": "testserver", | ||||||
|  |                 "timeout": 60000, | ||||||
|  |                 "userVerification": "preferred", | ||||||
|  |             }, | ||||||
|         ) |         ) | ||||||
|         challenge = get_webauthn_challenge_without_user(stage_view, stage) |  | ||||||
|         self.assertEqual(challenge["allowCredentials"], []) |  | ||||||
|         self.assertIsNotNone(challenge["challenge"]) |  | ||||||
|         self.assertEqual(challenge["rpId"], "testserver") |  | ||||||
|         self.assertEqual(challenge["timeout"], 60000) |  | ||||||
|         self.assertEqual(challenge["userVerification"], "preferred") |  | ||||||
|  |  | ||||||
|     def test_validate_challenge_unrestricted(self): |     def test_validate_challenge_unrestricted(self): | ||||||
|         """Test webauthn authentication (unrestricted webauthn device)""" |         """Test webauthn authentication (unrestricted webauthn device)""" | ||||||
| @ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|                 "last_used": None, |                 "last_used": None, | ||||||
|             } |             } | ||||||
|         ] |         ] | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||||
|             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" |             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" | ||||||
|         ) |         ) | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |         session.save() | ||||||
|  |  | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
| @ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|                 "last_used": None, |                 "last_used": None, | ||||||
|             } |             } | ||||||
|         ] |         ] | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||||
|             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" |             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" | ||||||
|         ) |         ) | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |         session.save() | ||||||
|  |  | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
| @ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|                 "last_used": None, |                 "last_used": None, | ||||||
|             } |             } | ||||||
|         ] |         ] | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" |             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||||
|         ) |         ) | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |         session.save() | ||||||
|  |  | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
| @ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | |||||||
|             not_configured_action=NotConfiguredAction.CONFIGURE, |             not_configured_action=NotConfiguredAction.CONFIGURE, | ||||||
|             device_classes=[DeviceClasses.WEBAUTHN], |             device_classes=[DeviceClasses.WEBAUTHN], | ||||||
|         ) |         ) | ||||||
|         plan = FlowPlan(flow.pk.hex) |         stage_view = AuthenticatorValidateStageView( | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( |             FlowExecutorView(flow=flow, current_stage=stage), request=request | ||||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" |  | ||||||
|         ) |         ) | ||||||
|         request = get_request("/") |         request = get_request("/") | ||||||
|  |         request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||||
|  |             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||||
|  |         ) | ||||||
|  |         request.session.save() | ||||||
|  |  | ||||||
|         stage_view = AuthenticatorValidateStageView( |         stage_view = AuthenticatorValidateStageView( | ||||||
|             FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request |             FlowExecutorView(flow=flow, current_stage=stage), request=request | ||||||
|         ) |         ) | ||||||
|         request.META["SERVER_NAME"] = "localhost" |         request.META["SERVER_NAME"] = "localhost" | ||||||
|         request.META["SERVER_PORT"] = "9000" |         request.META["SERVER_PORT"] = "9000" | ||||||
|  | |||||||
| @ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer): | |||||||
|             "resident_key_requirement", |             "resident_key_requirement", | ||||||
|             "device_type_restrictions", |             "device_type_restrictions", | ||||||
|             "device_type_restrictions_obj", |             "device_type_restrictions_obj", | ||||||
|             "max_attempts", |  | ||||||
|         ] |         ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @ -1,21 +0,0 @@ | |||||||
| # Generated by Django 5.1.11 on 2025-06-13 22:41 |  | ||||||
|  |  | ||||||
| from django.db import migrations, models |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Migration(migrations.Migration): |  | ||||||
|  |  | ||||||
|     dependencies = [ |  | ||||||
|         ( |  | ||||||
|             "authentik_stages_authenticator_webauthn", |  | ||||||
|             "0012_webauthndevice_created_webauthndevice_last_updated_and_more", |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     operations = [ |  | ||||||
|         migrations.AddField( |  | ||||||
|             model_name="authenticatorwebauthnstage", |  | ||||||
|             name="max_attempts", |  | ||||||
|             field=models.PositiveIntegerField(default=0), |  | ||||||
|         ), |  | ||||||
|     ] |  | ||||||
| @ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage): | |||||||
|  |  | ||||||
|     device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True) |     device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True) | ||||||
|  |  | ||||||
|     max_attempts = models.PositiveIntegerField(default=0) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> type[BaseSerializer]: |     def serializer(self) -> type[BaseSerializer]: | ||||||
|         from authentik.stages.authenticator_webauthn.api.stages import ( |         from authentik.stages.authenticator_webauthn.api.stages import ( | ||||||
|  | |||||||
| @ -5,13 +5,12 @@ from uuid import UUID | |||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.http.request import QueryDict | from django.http.request import QueryDict | ||||||
| from django.utils.translation import gettext as __ |  | ||||||
| from django.utils.translation import gettext_lazy as _ | from django.utils.translation import gettext_lazy as _ | ||||||
| from rest_framework.fields import CharField | from rest_framework.fields import CharField | ||||||
| from rest_framework.serializers import ValidationError | from rest_framework.serializers import ValidationError | ||||||
| from webauthn import options_to_json | from webauthn import options_to_json | ||||||
| from webauthn.helpers.bytes_to_base64url import bytes_to_base64url | from webauthn.helpers.bytes_to_base64url import bytes_to_base64url | ||||||
| from webauthn.helpers.exceptions import WebAuthnException | from webauthn.helpers.exceptions import InvalidRegistrationResponse | ||||||
| from webauthn.helpers.structs import ( | from webauthn.helpers.structs import ( | ||||||
|     AttestationConveyancePreference, |     AttestationConveyancePreference, | ||||||
|     AuthenticatorAttachment, |     AuthenticatorAttachment, | ||||||
| @ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import ( | |||||||
| ) | ) | ||||||
| from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | ||||||
|  |  | ||||||
| PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge" | SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge" | ||||||
| PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): | class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): | ||||||
| @ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): | |||||||
|  |  | ||||||
|     def validate_response(self, response: dict) -> dict: |     def validate_response(self, response: dict) -> dict: | ||||||
|         """Validate webauthn challenge response""" |         """Validate webauthn challenge response""" | ||||||
|         challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] |         challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             registration: VerifiedRegistration = verify_registration_response( |             registration: VerifiedRegistration = verify_registration_response( | ||||||
| @ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): | |||||||
|                 expected_rp_id=get_rp_id(self.request), |                 expected_rp_id=get_rp_id(self.request), | ||||||
|                 expected_origin=get_origin(self.request), |                 expected_origin=get_origin(self.request), | ||||||
|             ) |             ) | ||||||
|         except WebAuthnException as exc: |         except InvalidRegistrationResponse as exc: | ||||||
|             self.stage.logger.warning("registration failed", exc=exc) |             self.stage.logger.warning("registration failed", exc=exc) | ||||||
|             raise ValidationError(f"Registration failed. Error: {exc}") from None |             raise ValidationError(f"Registration failed. Error: {exc}") from None | ||||||
|  |  | ||||||
| @ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | |||||||
|     response_class = AuthenticatorWebAuthnChallengeResponse |     response_class = AuthenticatorWebAuthnChallengeResponse | ||||||
|  |  | ||||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: |     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||||
|  |         # clear session variables prior to starting a new registration | ||||||
|  |         self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||||
|         stage: AuthenticatorWebAuthnStage = self.executor.current_stage |         stage: AuthenticatorWebAuthnStage = self.executor.current_stage | ||||||
|         self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0) |  | ||||||
|         # clear flow variables prior to starting a new registration |  | ||||||
|         self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) |  | ||||||
|         user = self.get_pending_user() |         user = self.get_pending_user() | ||||||
|  |  | ||||||
|         # library accepts none so we store null in the database, but if there is a value |         # library accepts none so we store null in the database, but if there is a value | ||||||
| @ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | |||||||
|             attestation=AttestationConveyancePreference.DIRECT, |             attestation=AttestationConveyancePreference.DIRECT, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge |         self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge | ||||||
|  |         self.request.session.save() | ||||||
|         return AuthenticatorWebAuthnChallenge( |         return AuthenticatorWebAuthnChallenge( | ||||||
|             data={ |             data={ | ||||||
|                 "registration": loads(options_to_json(registration_options)), |                 "registration": loads(options_to_json(registration_options)), | ||||||
| @ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | |||||||
|         response.user = self.get_pending_user() |         response.user = self.get_pending_user() | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
|     def challenge_invalid(self, response): |  | ||||||
|         stage: AuthenticatorWebAuthnStage = self.executor.current_stage |  | ||||||
|         self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0) |  | ||||||
|         self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1 |  | ||||||
|         if ( |  | ||||||
|             stage.max_attempts > 0 |  | ||||||
|             and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts |  | ||||||
|         ): |  | ||||||
|             return self.executor.stage_invalid( |  | ||||||
|                 __( |  | ||||||
|                     "Exceeded maximum attempts. " |  | ||||||
|                     "Contact your {brand} administrator for help.".format( |  | ||||||
|                         brand=self.request.brand.branding_title |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             ) |  | ||||||
|         return super().challenge_invalid(response) |  | ||||||
|  |  | ||||||
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: |     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||||
|         # Webauthn Challenge has already been validated |         # Webauthn Challenge has already been validated | ||||||
|         webauthn_credential: VerifiedRegistration = response.validated_data["response"] |         webauthn_credential: VerifiedRegistration = response.validated_data["response"] | ||||||
| @ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | |||||||
|         else: |         else: | ||||||
|             return self.executor.stage_invalid("Device with Credential ID already exists.") |             return self.executor.stage_invalid("Device with Credential ID already exists.") | ||||||
|         return self.executor.stage_ok() |         return self.executor.stage_ok() | ||||||
|  |  | ||||||
|  |     def cleanup(self): | ||||||
|  |         self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ from authentik.stages.authenticator_webauthn.models import ( | |||||||
|     WebAuthnDevice, |     WebAuthnDevice, | ||||||
|     WebAuthnDeviceType, |     WebAuthnDeviceType, | ||||||
| ) | ) | ||||||
| from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE | from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE | ||||||
| from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import | from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -57,9 +57,6 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         plan: FlowPlan = self.client.session[SESSION_KEY_PLAN] |  | ||||||
|  |  | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         self.assertStageResponse( |         self.assertStageResponse( | ||||||
| @ -73,7 +70,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|                     "name": self.user.username, |                     "name": self.user.username, | ||||||
|                     "displayName": self.user.name, |                     "displayName": self.user.name, | ||||||
|                 }, |                 }, | ||||||
|                 "challenge": bytes_to_base64url(plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]), |                 "challenge": bytes_to_base64url(session[SESSION_KEY_WEBAUTHN_CHALLENGE]), | ||||||
|                 "pubKeyCredParams": [ |                 "pubKeyCredParams": [ | ||||||
|                     {"type": "public-key", "alg": -7}, |                     {"type": "public-key", "alg": -7}, | ||||||
|                     {"type": "public-key", "alg": -8}, |                     {"type": "public-key", "alg": -8}, | ||||||
| @ -100,11 +97,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|         """Test registration""" |         """Test registration""" | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user |         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode( |  | ||||||
|             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode( | ||||||
|  |             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" | ||||||
|  |         ) | ||||||
|         session.save() |         session.save() | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), | ||||||
| @ -149,11 +146,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|  |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user |         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode( |  | ||||||
|             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode( | ||||||
|  |             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" | ||||||
|  |         ) | ||||||
|         session.save() |         session.save() | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), | ||||||
| @ -212,11 +209,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|  |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user |         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode( |  | ||||||
|             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode( | ||||||
|  |             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" | ||||||
|  |         ) | ||||||
|         session.save() |         session.save() | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), | ||||||
| @ -262,11 +259,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|  |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user |         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode( |  | ||||||
|             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |         session = self.client.session | ||||||
|         session[SESSION_KEY_PLAN] = plan |         session[SESSION_KEY_PLAN] = plan | ||||||
|  |         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode( | ||||||
|  |             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" | ||||||
|  |         ) | ||||||
|         session.save() |         session.save() | ||||||
|         response = self.client.post( |         response = self.client.post( | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), | ||||||
| @ -301,109 +298,3 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase): | |||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|         self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) |         self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) | ||||||
|         self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists()) |         self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists()) | ||||||
|  |  | ||||||
|     def test_register_max_retries(self): |  | ||||||
|         """Test registration (exceeding max retries)""" |  | ||||||
|         self.stage.max_attempts = 2 |  | ||||||
|         self.stage.save() |  | ||||||
|  |  | ||||||
|         plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) |  | ||||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = self.user |  | ||||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode( |  | ||||||
|             b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw==" |  | ||||||
|         ) |  | ||||||
|         session = self.client.session |  | ||||||
|         session[SESSION_KEY_PLAN] = plan |  | ||||||
|         session.save() |  | ||||||
|  |  | ||||||
|         # first failed request |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             data={ |  | ||||||
|                 "component": "ak-stage-authenticator-webauthn", |  | ||||||
|                 "response": { |  | ||||||
|                     "id": "kqnmrVLnDG-OwsSNHkihYZaNz5s", |  | ||||||
|                     "rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s", |  | ||||||
|                     "type": "public-key", |  | ||||||
|                     "registrationClientExtensions": "{}", |  | ||||||
|                     "response": { |  | ||||||
|                         "clientDataJSON": ( |  | ||||||
|                             "eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd" |  | ||||||
|                             "lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV" |  | ||||||
|                             "pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU" |  | ||||||
|                             "mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw" |  | ||||||
|                             "Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF" |  | ||||||
|                         ), |  | ||||||
|                         "attestationObject": ( |  | ||||||
|                             "o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg" |  | ||||||
|                             "OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA" |  | ||||||
|                             "cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp" |  | ||||||
|                             "QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq" |  | ||||||
|                             "2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds" |  | ||||||
|                         ), |  | ||||||
|                     }, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|             SERVER_NAME="localhost", |  | ||||||
|             SERVER_PORT="9000", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow=self.flow, |  | ||||||
|             component="ak-stage-authenticator-webauthn", |  | ||||||
|             response_errors={ |  | ||||||
|                 "response": [ |  | ||||||
|                     { |  | ||||||
|                         "string": ( |  | ||||||
|                             "Registration failed. Error: Unable to decode " |  | ||||||
|                             "client_data_json bytes as JSON" |  | ||||||
|                         ), |  | ||||||
|                         "code": "invalid", |  | ||||||
|                     } |  | ||||||
|                 ] |  | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists()) |  | ||||||
|  |  | ||||||
|         # Second failed request |  | ||||||
|         response = self.client.post( |  | ||||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), |  | ||||||
|             data={ |  | ||||||
|                 "component": "ak-stage-authenticator-webauthn", |  | ||||||
|                 "response": { |  | ||||||
|                     "id": "kqnmrVLnDG-OwsSNHkihYZaNz5s", |  | ||||||
|                     "rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s", |  | ||||||
|                     "type": "public-key", |  | ||||||
|                     "registrationClientExtensions": "{}", |  | ||||||
|                     "response": { |  | ||||||
|                         "clientDataJSON": ( |  | ||||||
|                             "eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd" |  | ||||||
|                             "lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV" |  | ||||||
|                             "pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU" |  | ||||||
|                             "mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw" |  | ||||||
|                             "Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF" |  | ||||||
|                         ), |  | ||||||
|                         "attestationObject": ( |  | ||||||
|                             "o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg" |  | ||||||
|                             "OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA" |  | ||||||
|                             "cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp" |  | ||||||
|                             "QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq" |  | ||||||
|                             "2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds" |  | ||||||
|                         ), |  | ||||||
|                     }, |  | ||||||
|                 }, |  | ||||||
|             }, |  | ||||||
|             SERVER_NAME="localhost", |  | ||||||
|             SERVER_PORT="9000", |  | ||||||
|         ) |  | ||||||
|         self.assertEqual(response.status_code, 200) |  | ||||||
|         self.assertStageResponse( |  | ||||||
|             response, |  | ||||||
|             flow=self.flow, |  | ||||||
|             component="ak-stage-access-denied", |  | ||||||
|             error_message=( |  | ||||||
|                 "Exceeded maximum attempts. Contact your authentik administrator for help." |  | ||||||
|             ), |  | ||||||
|         ) |  | ||||||
|         self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists()) |  | ||||||
|  | |||||||
| @ -101,9 +101,9 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|             SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING |             SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING | ||||||
|         ) |         ) | ||||||
|         if configured_binding_net != NetworkBinding.NO_BINDING: |         if configured_binding_net != NetworkBinding.NO_BINDING: | ||||||
|             BoundSessionMiddleware.recheck_session_net(configured_binding_net, last_ip, new_ip) |             self.recheck_session_net(configured_binding_net, last_ip, new_ip) | ||||||
|         if configured_binding_geo != GeoIPBinding.NO_BINDING: |         if configured_binding_geo != GeoIPBinding.NO_BINDING: | ||||||
|             BoundSessionMiddleware.recheck_session_geo(configured_binding_geo, last_ip, new_ip) |             self.recheck_session_geo(configured_binding_geo, last_ip, new_ip) | ||||||
|         # If we got to this point without any error being raised, we need to |         # If we got to this point without any error being raised, we need to | ||||||
|         # update the last saved IP to the current one |         # update the last saved IP to the current one | ||||||
|         if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session: |         if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session: | ||||||
| @ -111,8 +111,7 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|             # (== basically requires the user to be logged in) |             # (== basically requires the user to be logged in) | ||||||
|             request.session[request.session.model.Keys.LAST_IP] = new_ip |             request.session[request.session.model.Keys.LAST_IP] = new_ip | ||||||
|  |  | ||||||
|     @staticmethod |     def recheck_session_net(self, binding: NetworkBinding, last_ip: str, new_ip: str): | ||||||
|     def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str): |  | ||||||
|         """Check network/ASN binding""" |         """Check network/ASN binding""" | ||||||
|         last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip) |         last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip) | ||||||
|         new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip) |         new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip) | ||||||
| @ -159,8 +158,7 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|                     new_ip, |                     new_ip, | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|     @staticmethod |     def recheck_session_geo(self, binding: GeoIPBinding, last_ip: str, new_ip: str): | ||||||
|     def recheck_session_geo(binding: GeoIPBinding, last_ip: str, new_ip: str): |  | ||||||
|         """Check GeoIP binding""" |         """Check GeoIP binding""" | ||||||
|         last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip) |         last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip) | ||||||
|         new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip) |         new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip) | ||||||
| @ -181,8 +179,8 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|             if last_geo.continent != new_geo.continent: |             if last_geo.continent != new_geo.continent: | ||||||
|                 raise SessionBindingBroken( |                 raise SessionBindingBroken( | ||||||
|                     "geoip.continent", |                     "geoip.continent", | ||||||
|                     last_geo.continent.to_dict(), |                     last_geo.continent, | ||||||
|                     new_geo.continent.to_dict(), |                     new_geo.continent, | ||||||
|                     last_ip, |                     last_ip, | ||||||
|                     new_ip, |                     new_ip, | ||||||
|                 ) |                 ) | ||||||
| @ -194,8 +192,8 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|             if last_geo.country != new_geo.country: |             if last_geo.country != new_geo.country: | ||||||
|                 raise SessionBindingBroken( |                 raise SessionBindingBroken( | ||||||
|                     "geoip.country", |                     "geoip.country", | ||||||
|                     last_geo.country.to_dict(), |                     last_geo.country, | ||||||
|                     new_geo.country.to_dict(), |                     new_geo.country, | ||||||
|                     last_ip, |                     last_ip, | ||||||
|                     new_ip, |                     new_ip, | ||||||
|                 ) |                 ) | ||||||
| @ -204,8 +202,8 @@ class BoundSessionMiddleware(SessionMiddleware): | |||||||
|             if last_geo.city != new_geo.city: |             if last_geo.city != new_geo.city: | ||||||
|                 raise SessionBindingBroken( |                 raise SessionBindingBroken( | ||||||
|                     "geoip.city", |                     "geoip.city", | ||||||
|                     last_geo.city.to_dict(), |                     last_geo.city, | ||||||
|                     new_geo.city.to_dict(), |                     new_geo.city, | ||||||
|                     last_ip, |                     last_ip, | ||||||
|                     new_ip, |                     new_ip, | ||||||
|                 ) |                 ) | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ | |||||||
| from time import sleep | from time import sleep | ||||||
| from unittest.mock import patch | from unittest.mock import patch | ||||||
|  |  | ||||||
| from django.http import HttpRequest |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
|  |  | ||||||
| @ -18,12 +17,7 @@ from authentik.flows.views.executor import SESSION_KEY_PLAN | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.root.middleware import ClientIPMiddleware | from authentik.root.middleware import ClientIPMiddleware | ||||||
| from authentik.stages.user_login.middleware import ( | from authentik.stages.user_login.models import UserLoginStage | ||||||
|     BoundSessionMiddleware, |  | ||||||
|     SessionBindingBroken, |  | ||||||
|     logout_extra, |  | ||||||
| ) |  | ||||||
| from authentik.stages.user_login.models import GeoIPBinding, NetworkBinding, UserLoginStage |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestUserLoginStage(FlowTestCase): | class TestUserLoginStage(FlowTestCase): | ||||||
| @ -198,52 +192,3 @@ class TestUserLoginStage(FlowTestCase): | |||||||
|         self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) |         self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) | ||||||
|         response = self.client.get(reverse("authentik_api:application-list")) |         response = self.client.get(reverse("authentik_api:application-list")) | ||||||
|         self.assertEqual(response.status_code, 403) |         self.assertEqual(response.status_code, 403) | ||||||
|  |  | ||||||
|     def test_binding_net_break_log(self): |  | ||||||
|         """Test logout_extra with exception""" |  | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json |  | ||||||
|         for args, expect in [ |  | ||||||
|             [[NetworkBinding.BIND_ASN, "8.8.8.8", "8.8.8.8"], ["network.missing"]], |  | ||||||
|             [[NetworkBinding.BIND_ASN, "1.0.0.1", "1.128.0.1"], ["network.asn"]], |  | ||||||
|             [ |  | ||||||
|                 [NetworkBinding.BIND_ASN_NETWORK, "12.81.96.1", "12.81.128.1"], |  | ||||||
|                 ["network.asn_network"], |  | ||||||
|             ], |  | ||||||
|             [[NetworkBinding.BIND_ASN_NETWORK_IP, "1.0.0.1", "1.0.0.2"], ["network.ip"]], |  | ||||||
|         ]: |  | ||||||
|             with self.subTest(args[0]): |  | ||||||
|                 with self.assertRaises(SessionBindingBroken) as cm: |  | ||||||
|                     BoundSessionMiddleware.recheck_session_net(*args) |  | ||||||
|                 self.assertEqual(cm.exception.reason, expect[0]) |  | ||||||
|                 # Ensure the request can be logged without throwing errors |  | ||||||
|                 self.client.force_login(self.user) |  | ||||||
|                 request = HttpRequest() |  | ||||||
|                 request.session = self.client.session |  | ||||||
|                 request.user = self.user |  | ||||||
|                 logout_extra(request, cm.exception) |  | ||||||
|  |  | ||||||
|     def test_binding_geo_break_log(self): |  | ||||||
|         """Test logout_extra with exception""" |  | ||||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json |  | ||||||
|         for args, expect in [ |  | ||||||
|             [[GeoIPBinding.BIND_CONTINENT, "8.8.8.8", "8.8.8.8"], ["geoip.missing"]], |  | ||||||
|             [[GeoIPBinding.BIND_CONTINENT, "2.125.160.216", "67.43.156.1"], ["geoip.continent"]], |  | ||||||
|             [ |  | ||||||
|                 [GeoIPBinding.BIND_CONTINENT_COUNTRY, "81.2.69.142", "89.160.20.112"], |  | ||||||
|                 ["geoip.country"], |  | ||||||
|             ], |  | ||||||
|             [ |  | ||||||
|                 [GeoIPBinding.BIND_CONTINENT_COUNTRY_CITY, "2.125.160.216", "81.2.69.142"], |  | ||||||
|                 ["geoip.city"], |  | ||||||
|             ], |  | ||||||
|         ]: |  | ||||||
|             with self.subTest(args[0]): |  | ||||||
|                 with self.assertRaises(SessionBindingBroken) as cm: |  | ||||||
|                     BoundSessionMiddleware.recheck_session_geo(*args) |  | ||||||
|                 self.assertEqual(cm.exception.reason, expect[0]) |  | ||||||
|                 # Ensure the request can be logged without throwing errors |  | ||||||
|                 self.client.force_login(self.user) |  | ||||||
|                 request = HttpRequest() |  | ||||||
|                 request.session = self.client.session |  | ||||||
|                 request.user = self.user |  | ||||||
|                 logout_extra(request, cm.exception) |  | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	