Compare commits

6 Commits

website/do ... files-in-d
| SHA1 |
|---|
| d410083cfc |
| 6045f96a05 |
| c50df0f843 |
| c8ebd9f74b |
| b3f441f2cd |
| 647f054be3 |
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2025.6.3 | ||||
| current_version = 2025.6.1 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
| @ -21,8 +21,6 @@ optional_value = final | ||||
|  | ||||
| [bumpversion:file:package.json] | ||||
|  | ||||
| [bumpversion:file:package-lock.json] | ||||
|  | ||||
| [bumpversion:file:docker-compose.yml] | ||||
|  | ||||
| [bumpversion:file:schema.yml] | ||||
| @ -33,4 +31,6 @@ optional_value = final | ||||
|  | ||||
| [bumpversion:file:internal/constants/constants.go] | ||||
|  | ||||
| [bumpversion:file:web/src/common/constants.ts] | ||||
|  | ||||
| [bumpversion:file:lifecycle/aws/template.yaml] | ||||
|  | ||||
| @ -7,9 +7,6 @@ charset = utf-8 | ||||
| trim_trailing_whitespace = true | ||||
| insert_final_newline = true | ||||
|  | ||||
| [*.toml] | ||||
| indent_size = 2 | ||||
|  | ||||
| [*.html] | ||||
| indent_size = 2 | ||||
|  | ||||
|  | ||||
| @ -38,8 +38,6 @@ jobs: | ||||
|       # Needed for attestation | ||||
|       id-token: write | ||||
|       attestations: write | ||||
|       # Needed for checkout | ||||
|       contents: read | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - uses: docker/setup-qemu-action@v3.6.0 | ||||
|  | ||||

.github/workflows/ci-main-daily.yml (3 changed lines, vendored)
							| @ -9,15 +9,14 @@ on: | ||||
|  | ||||
| jobs: | ||||
|   test-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         version: | ||||
|           - docs | ||||
|           - version-2025-4 | ||||
|           - version-2025-2 | ||||
|           - version-2024-12 | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - run: | | ||||
|  | ||||

.github/workflows/ci-main.yml (6 changed lines, vendored)
							| @ -202,7 +202,7 @@ jobs: | ||||
|         uses: actions/cache@v4 | ||||
|         with: | ||||
|           path: web/dist | ||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||
|       - name: prepare web ui | ||||
|         if: steps.cache-web.outputs.cache-hit != 'true' | ||||
|         working-directory: web | ||||
| @ -247,13 +247,11 @@ jobs: | ||||
|       # Needed for attestation | ||||
|       id-token: write | ||||
|       attestations: write | ||||
|       # Needed for checkout | ||||
|       contents: read | ||||
|     needs: ci-core-mark | ||||
|     uses: ./.github/workflows/_reusable-docker-build.yaml | ||||
|     secrets: inherit | ||||
|     with: | ||||
|       image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }} | ||||
|       image_name: ghcr.io/goauthentik/dev-server | ||||
|       release: false | ||||
|   pr-comment: | ||||
|     needs: | ||||
|  | ||||

.github/workflows/ci-outpost.yml (1 changed line, vendored)
							| @ -59,7 +59,6 @@ jobs: | ||||
|         with: | ||||
|           jobs: ${{ toJSON(needs) }} | ||||
|   build-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     timeout-minutes: 120 | ||||
|     needs: | ||||
|       - ci-outpost-mark | ||||
|  | ||||

.github/workflows/ci-website.yml (24 changed lines, vendored)
							| @ -41,29 +41,7 @@ jobs: | ||||
|       - name: test | ||||
|         working-directory: website/ | ||||
|         run: npm test | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     name: ${{ matrix.job }} | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         job: | ||||
|           - build | ||||
|           - build:integrations | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - uses: actions/setup-node@v4 | ||||
|         with: | ||||
|           node-version-file: website/package.json | ||||
|           cache: "npm" | ||||
|           cache-dependency-path: website/package-lock.json | ||||
|       - working-directory: website/ | ||||
|         run: npm ci | ||||
|       - name: build | ||||
|         working-directory: website/ | ||||
|         run: npm run ${{ matrix.job }} | ||||
|   build-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       # Needed to upload container images to ghcr.io | ||||
| @ -116,11 +94,9 @@ jobs: | ||||
|     needs: | ||||
|       - lint | ||||
|       - test | ||||
|       - build | ||||
|       - build-container | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: re-actors/alls-green@release/v1 | ||||
|         with: | ||||
|           jobs: ${{ toJSON(needs) }} | ||||
|           allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }} | ||||
|  | ||||

.github/workflows/codeql-analysis.yml (2 changed lines, vendored)
							| @ -2,7 +2,7 @@ name: "CodeQL" | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     branches: [main, next, version*] | ||||
|     branches: [main, "*", next, version*] | ||||
|   pull_request: | ||||
|     branches: [main] | ||||
|   schedule: | ||||
|  | ||||

.github/workflows/repo-mirror-cleanup.yml (21 changed lines, vendored)
							| @ -1,21 +0,0 @@ | ||||
| name: "authentik-repo-mirror-cleanup" | ||||
|  | ||||
| on: | ||||
|   workflow_dispatch: | ||||
|  | ||||
| jobs: | ||||
|   to_internal: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - if: ${{ env.MIRROR_KEY != '' }} | ||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb | ||||
|         with: | ||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} | ||||
|           args: --tags --force --prune | ||||
|         env: | ||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} | ||||

.github/workflows/repo-mirror.yml (9 changed lines, vendored)
							| @ -11,10 +11,11 @@ jobs: | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - if: ${{ env.MIRROR_KEY != '' }} | ||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb | ||||
|         uses: pixta-dev/repository-mirroring-action@v1 | ||||
|         with: | ||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} | ||||
|           args: --tags --force | ||||
|           target_repo_url: | ||||
|             git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: | ||||
|             ${{ secrets.GH_MIRROR_KEY }} | ||||
|         env: | ||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} | ||||
|  | ||||
| @ -16,7 +16,6 @@ env: | ||||
|  | ||||
| jobs: | ||||
|   compile: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - id: generate_token | ||||
|  | ||||

.vscode/settings.json (4 changed lines, vendored)
							| @ -6,15 +6,13 @@ | ||||
|         "!Context scalar", | ||||
|         "!Enumerate sequence", | ||||
|         "!Env scalar", | ||||
|         "!Env sequence", | ||||
|         "!Find sequence", | ||||
|         "!Format sequence", | ||||
|         "!If sequence", | ||||
|         "!Index scalar", | ||||
|         "!KeyOf scalar", | ||||
|         "!Value scalar", | ||||
|         "!AtIndex scalar", | ||||
|         "!ParseJSON scalar" | ||||
|         "!AtIndex scalar" | ||||
|     ], | ||||
|     "typescript.preferences.importModuleSpecifier": "non-relative", | ||||
|     "typescript.preferences.importModuleSpecifierEnding": "index", | ||||
|  | ||||
| @ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | ||||
|     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||
|  | ||||
| # Stage 4: Download uv | ||||
| FROM ghcr.io/astral-sh/uv:0.7.17 AS uv | ||||
| FROM ghcr.io/astral-sh/uv:0.7.13 AS uv | ||||
| # Stage 5: Base python image | ||||
| FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base | ||||
| FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base | ||||
|  | ||||
| ENV VENV_PATH="/ak-root/.venv" \ | ||||
|     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ | ||||
|  | ||||

Makefile (12 changed lines)
							| @ -86,10 +86,6 @@ dev-create-db: | ||||
|  | ||||
| dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | ||||
|  | ||||
| update-test-mmdb:  ## Update test GeoIP and ASN Databases | ||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb | ||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb | ||||
|  | ||||
| ######################### | ||||
| ## API Schema | ||||
| ######################### | ||||
| @ -98,7 +94,7 @@ gen-build:  ## Extract the schema from the database | ||||
| 	AUTHENTIK_DEBUG=true \ | ||||
| 		AUTHENTIK_TENANTS__ENABLED=true \ | ||||
| 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | ||||
| 		uv run ak make_blueprint_schema --file blueprints/schema.json | ||||
| 		uv run ak make_blueprint_schema > blueprints/schema.json | ||||
| 	AUTHENTIK_DEBUG=true \ | ||||
| 		AUTHENTIK_TENANTS__ENABLED=true \ | ||||
| 		AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ | ||||
| @ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | ||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||
| 		--git-repo-id authentik \ | ||||
| 		--git-user-id goauthentik | ||||
|  | ||||
| 	cd ${PWD}/${GEN_API_TS} && npm link | ||||
| 	cd ${PWD}/web && npm link @goauthentik/api | ||||
| 	mkdir -p web/node_modules/@goauthentik/api | ||||
| 	cd ${PWD}/${GEN_API_TS} && npm i | ||||
| 	\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||
|  | ||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python | ||||
| 	docker run \ | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from os import environ | ||||
|  | ||||
| __version__ = "2025.6.3" | ||||
| __version__ = "2025.6.1" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -72,33 +72,20 @@ class Command(BaseCommand): | ||||
|                     "additionalProperties": True, | ||||
|                 }, | ||||
|                 "entries": { | ||||
|                     "anyOf": [ | ||||
|                         { | ||||
|                             "type": "array", | ||||
|                             "items": {"$ref": "#/$defs/blueprint_entry"}, | ||||
|                         }, | ||||
|                         { | ||||
|                             "type": "object", | ||||
|                             "additionalProperties": { | ||||
|                                 "type": "array", | ||||
|                                 "items": {"$ref": "#/$defs/blueprint_entry"}, | ||||
|                             }, | ||||
|                         }, | ||||
|                     ], | ||||
|                     "type": "array", | ||||
|                     "items": { | ||||
|                         "oneOf": [], | ||||
|                     }, | ||||
|                 }, | ||||
|             }, | ||||
|             "$defs": {"blueprint_entry": {"oneOf": []}}, | ||||
|             "$defs": {}, | ||||
|         } | ||||
|  | ||||
|     def add_arguments(self, parser): | ||||
|         parser.add_argument("--file", type=str) | ||||
|  | ||||
|     @no_translations | ||||
|     def handle(self, *args, file: str, **options): | ||||
|     def handle(self, *args, **options): | ||||
|         """Generate JSON Schema for blueprints""" | ||||
|         self.build() | ||||
|         with open(file, "w") as _schema: | ||||
|             _schema.write(dumps(self.schema, indent=4, default=Command.json_default)) | ||||
|         self.stdout.write(dumps(self.schema, indent=4, default=Command.json_default)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def json_default(value: Any) -> Any: | ||||
| @ -125,7 +112,7 @@ class Command(BaseCommand): | ||||
|                 } | ||||
|             ) | ||||
|             model_path = f"{model._meta.app_label}.{model._meta.model_name}" | ||||
|             self.schema["$defs"]["blueprint_entry"]["oneOf"].append( | ||||
|             self.schema["properties"]["entries"]["items"]["oneOf"].append( | ||||
|                 self.template_entry(model_path, model, serializer) | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @ -1,11 +1,10 @@ | ||||
| version: 1 | ||||
| entries: | ||||
|   foo: | ||||
|       - identifiers: | ||||
|             name: "%(id)s" | ||||
|             slug: "%(id)s" | ||||
|         model: authentik_flows.flow | ||||
|         state: present | ||||
|         attrs: | ||||
|             designation: stage_configuration | ||||
|             title: foo | ||||
|     - identifiers: | ||||
|           name: "%(id)s" | ||||
|           slug: "%(id)s" | ||||
|       model: authentik_flows.flow | ||||
|       state: present | ||||
|       attrs: | ||||
|           designation: stage_configuration | ||||
|           title: foo | ||||
|  | ||||
| @ -37,7 +37,6 @@ entries: | ||||
|     - attrs: | ||||
|           attributes: | ||||
|               env_null: !Env [bar-baz, null] | ||||
|               json_parse: !ParseJSON '{"foo": "bar"}' | ||||
|               policy_pk1: | ||||
|                   !Format [ | ||||
|                       "%s-%s", | ||||
|  | ||||
| @ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable: | ||||
|  | ||||
|  | ||||
| for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||
|     if "local" in str(blueprint_file) or "testing" in str(blueprint_file): | ||||
|     if "local" in str(blueprint_file): | ||||
|         continue | ||||
|     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||
|  | ||||
| @ -5,6 +5,7 @@ from collections.abc import Callable | ||||
| from django.apps import apps | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.blueprints.v1.importer import is_model_allowed | ||||
| from authentik.lib.models import SerializerModel | ||||
| from authentik.providers.oauth2.models import RefreshToken | ||||
|  | ||||
| @ -21,13 +22,10 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: | ||||
|             return | ||||
|         model_class = test_model() | ||||
|         self.assertTrue(isinstance(model_class, SerializerModel)) | ||||
|         # Models that have subclasses don't have to have a serializer | ||||
|         if len(test_model.__subclasses__()) > 0: | ||||
|             return | ||||
|         self.assertIsNotNone(model_class.serializer) | ||||
|         if model_class.serializer.Meta().model == RefreshToken: | ||||
|             return | ||||
|         self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model)) | ||||
|         self.assertEqual(model_class.serializer.Meta().model, test_model) | ||||
|  | ||||
|     return tester | ||||
|  | ||||
| @ -36,6 +34,6 @@ for app in apps.get_app_configs(): | ||||
|     if not app.label.startswith("authentik"): | ||||
|         continue | ||||
|     for model in app.get_models(): | ||||
|         if not issubclass(model, SerializerModel): | ||||
|         if not is_model_allowed(model): | ||||
|             continue | ||||
|         setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model)) | ||||
|  | ||||
| @ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase): | ||||
|                     }, | ||||
|                     "nested_context": "context-nested-value", | ||||
|                     "env_null": None, | ||||
|                     "json_parse": {"foo": "bar"}, | ||||
|                     "at_index_sequence": "foo", | ||||
|                     "at_index_sequence_default": "non existent", | ||||
|                     "at_index_mapping": 2, | ||||
|  | ||||
| @ -6,7 +6,6 @@ from copy import copy | ||||
| from dataclasses import asdict, dataclass, field, is_dataclass | ||||
| from enum import Enum | ||||
| from functools import reduce | ||||
| from json import JSONDecodeError, loads | ||||
| from operator import ixor | ||||
| from os import getenv | ||||
| from typing import Any, Literal, Union | ||||
| @ -192,18 +191,11 @@ class Blueprint: | ||||
|     """Dataclass used for a full export""" | ||||
|  | ||||
|     version: int = field(default=1) | ||||
|     entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = field(default_factory=list) | ||||
|     entries: list[BlueprintEntry] = field(default_factory=list) | ||||
|     context: dict = field(default_factory=dict) | ||||
|  | ||||
|     metadata: BlueprintMetadata | None = field(default=None) | ||||
|  | ||||
|     def iter_entries(self) -> Iterable[BlueprintEntry]: | ||||
|         if isinstance(self.entries, dict): | ||||
|             for _section, entries in self.entries.items(): | ||||
|                 yield from entries | ||||
|         else: | ||||
|             yield from self.entries | ||||
|  | ||||
|  | ||||
| class YAMLTag: | ||||
|     """Base class for all YAML Tags""" | ||||
| @ -234,7 +226,7 @@ class KeyOf(YAMLTag): | ||||
|         self.id_from = node.value | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         for _entry in blueprint.iter_entries(): | ||||
|         for _entry in blueprint.entries: | ||||
|             if _entry.id == self.id_from and _entry._state.instance: | ||||
|                 # Special handling for PolicyBindingModels, as they'll have a different PK | ||||
|                 # which is used when creating policy bindings | ||||
| @ -292,22 +284,6 @@ class Context(YAMLTag): | ||||
|         return value | ||||
|  | ||||
|  | ||||
| class ParseJSON(YAMLTag): | ||||
|     """Parse JSON from context/env/etc value""" | ||||
|  | ||||
|     raw: str | ||||
|  | ||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: | ||||
|         super().__init__() | ||||
|         self.raw = node.value | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         try: | ||||
|             return loads(self.raw) | ||||
|         except JSONDecodeError as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|  | ||||
| class Format(YAMLTag): | ||||
|     """Format a string""" | ||||
|  | ||||
| @ -683,7 +659,6 @@ class BlueprintLoader(SafeLoader): | ||||
|         self.add_constructor("!Value", Value) | ||||
|         self.add_constructor("!Index", Index) | ||||
|         self.add_constructor("!AtIndex", AtIndex) | ||||
|         self.add_constructor("!ParseJSON", ParseJSON) | ||||
|  | ||||
|  | ||||
| class EntryInvalidError(SentryIgnoredException): | ||||
|  | ||||
| @ -384,7 +384,7 @@ class Importer: | ||||
|     def _apply_models(self, raise_errors=False) -> bool: | ||||
|         """Apply (create/update) models yaml""" | ||||
|         self.__pk_map = {} | ||||
|         for entry in self._import.iter_entries(): | ||||
|         for entry in self._import.entries: | ||||
|             model_app_label, model_name = entry.get_model(self._import).split(".") | ||||
|             try: | ||||
|                 model: type[SerializerModel] = registry.get_model(model_app_label, model_name) | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| """Authenticator Devices API Views""" | ||||
|  | ||||
| from drf_spectacular.utils import extend_schema | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiParameter, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.fields import ( | ||||
|     BooleanField, | ||||
| @ -13,7 +15,6 @@ from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.viewsets import ViewSet | ||||
|  | ||||
| from authentik.core.api.users import ParamUserSerializer | ||||
| from authentik.core.api.utils import MetaNameSerializer | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | ||||
| from authentik.stages.authenticator import device_classes, devices_for_user | ||||
| @ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice | ||||
|  | ||||
|  | ||||
| class DeviceSerializer(MetaNameSerializer): | ||||
|     """Serializer for authenticator devices""" | ||||
|     """Serializer for Duo authenticator devices""" | ||||
|  | ||||
|     pk = CharField() | ||||
|     name = CharField() | ||||
| @ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer): | ||||
|     last_updated = DateTimeField(read_only=True) | ||||
|     last_used = DateTimeField(read_only=True, allow_null=True) | ||||
|     extra_description = SerializerMethodField() | ||||
|     external_id = SerializerMethodField() | ||||
|  | ||||
|     def get_type(self, instance: Device) -> str: | ||||
|         """Get type of device""" | ||||
|         return instance._meta.label | ||||
|  | ||||
|     def get_extra_description(self, instance: Device) -> str | None: | ||||
|     def get_extra_description(self, instance: Device) -> str: | ||||
|         """Get extra description""" | ||||
|         if isinstance(instance, WebAuthnDevice): | ||||
|             return instance.device_type.description if instance.device_type else None | ||||
|             return ( | ||||
|                 instance.device_type.description | ||||
|                 if instance.device_type | ||||
|                 else _("Extra description not available") | ||||
|             ) | ||||
|         if isinstance(instance, EndpointDevice): | ||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||
|         return None | ||||
|  | ||||
|     def get_external_id(self, instance: Device) -> str | None: | ||||
|         """Get external Device ID""" | ||||
|         if isinstance(instance, WebAuthnDevice): | ||||
|             return instance.device_type.aaguid if instance.device_type else None | ||||
|         if isinstance(instance, EndpointDevice): | ||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||
|         return None | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class DeviceViewSet(ViewSet): | ||||
| @ -61,6 +57,7 @@ class DeviceViewSet(ViewSet): | ||||
|     serializer_class = DeviceSerializer | ||||
|     permission_classes = [IsAuthenticated] | ||||
|  | ||||
|     @extend_schema(responses={200: DeviceSerializer(many=True)}) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """Get all devices for current user""" | ||||
|         devices = devices_for_user(request.user) | ||||
| @ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet): | ||||
|             yield from device_set | ||||
|  | ||||
|     @extend_schema( | ||||
|         parameters=[ParamUserSerializer], | ||||
|         parameters=[ | ||||
|             OpenApiParameter( | ||||
|                 name="user", | ||||
|                 location=OpenApiParameter.QUERY, | ||||
|                 type=OpenApiTypes.INT, | ||||
|             ) | ||||
|         ], | ||||
|         responses={200: DeviceSerializer(many=True)}, | ||||
|     ) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """Get all devices for current user""" | ||||
|         args = ParamUserSerializer(data=request.query_params) | ||||
|         args.is_valid(raise_exception=True) | ||||
|         return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data) | ||||
|         kwargs = {} | ||||
|         if "user" in request.query_params: | ||||
|             kwargs = {"user": request.query_params["user"]} | ||||
|         return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data) | ||||
|  | ||||

authentik/core/api/files.py (32 changed lines, new file)
							| @ -0,0 +1,32 @@ | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.core.api.utils import ModelSerializer | ||||
| from authentik.core.models import File | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class FileSerializer(ModelSerializer): | ||||
|     class Meta: | ||||
|         model = File | ||||
|         fields = ( | ||||
|             "pk", | ||||
|             "name", | ||||
|             "content", | ||||
|             "location", | ||||
|             "private", | ||||
|             "url", | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class FileViewSet(UsedByMixin, ModelViewSet): | ||||
|     queryset = File.objects.all() | ||||
|     serializer_class = FileSerializer | ||||
|     filterset_fields = ("private",) | ||||
|     ordering = ("name",) | ||||
|     search_fields = ( | ||||
|         "name", | ||||
|         "location", | ||||
|     ) | ||||
| @ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class ParamUserSerializer(PassiveSerializer): | ||||
|     """Partial serializer for query parameters to select a user""" | ||||
|  | ||||
|     user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False) | ||||
|  | ||||
|  | ||||
| class UserGroupSerializer(ModelSerializer): | ||||
|     """Simplified Group Serializer for user's groups""" | ||||
|  | ||||
| @ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|     queryset = User.objects.none() | ||||
|     ordering = ["username"] | ||||
|     serializer_class = UserSerializer | ||||
|     filterset_class = UsersFilter | ||||
|     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] | ||||
|  | ||||
|     def get_ql_fields(self): | ||||
|         from djangoql.schema import BoolField, StrField | ||||
|  | ||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField | ||||
|  | ||||
|         return [ | ||||
|             StrField(User, "username"), | ||||
|             StrField(User, "name"), | ||||
|             StrField(User, "email"), | ||||
|             StrField(User, "path"), | ||||
|             BoolField(User, "is_active", nullable=True), | ||||
|             ChoiceSearchField(User, "type"), | ||||
|             JSONSearchField(User, "attributes", suggest_nested=False), | ||||
|         ] | ||||
|     filterset_class = UsersFilter | ||||
|  | ||||
|     def get_queryset(self): | ||||
|         base_qs = User.objects.all().exclude_anonymous() | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from django.db import models | ||||
| from django.db.models import Model | ||||
| from drf_spectacular.extensions import OpenApiSerializerFieldExtension | ||||
| from drf_spectacular.plumbing import build_basic_type | ||||
| @ -31,27 +30,7 @@ def is_dict(value: Any): | ||||
|     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") | ||||
|  | ||||
|  | ||||
| class JSONDictField(JSONField): | ||||
|     """JSON Field which only allows dictionaries""" | ||||
|  | ||||
|     default_validators = [is_dict] | ||||
|  | ||||
|  | ||||
| class JSONExtension(OpenApiSerializerFieldExtension): | ||||
|     """Generate API Schema for JSON fields as""" | ||||
|  | ||||
|     target_class = "authentik.core.api.utils.JSONDictField" | ||||
|  | ||||
|     def map_serializer_field(self, auto_schema, direction): | ||||
|         return build_basic_type(OpenApiTypes.OBJECT) | ||||
|  | ||||
|  | ||||
| class ModelSerializer(BaseModelSerializer): | ||||
|  | ||||
|     # By default, JSON fields we have are used to store dictionaries | ||||
|     serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy() | ||||
|     serializer_field_mapping[models.JSONField] = JSONDictField | ||||
|  | ||||
|     def create(self, validated_data): | ||||
|         instance = super().create(validated_data) | ||||
|  | ||||
| @ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer): | ||||
|         return instance | ||||
|  | ||||
|  | ||||
| class JSONDictField(JSONField): | ||||
|     """JSON Field which only allows dictionaries""" | ||||
|  | ||||
|     default_validators = [is_dict] | ||||
|  | ||||
|  | ||||
| class JSONExtension(OpenApiSerializerFieldExtension): | ||||
|     """Generate API Schema for JSON fields as""" | ||||
|  | ||||
|     target_class = "authentik.core.api.utils.JSONDictField" | ||||
|  | ||||
|     def map_serializer_field(self, auto_schema, direction): | ||||
|         return build_basic_type(OpenApiTypes.OBJECT) | ||||
|  | ||||
|  | ||||
| class PassiveSerializer(Serializer): | ||||
|     """Base serializer class which doesn't implement create/update methods""" | ||||
|  | ||||
|  | ||||
| @ -13,6 +13,7 @@ class Command(TenantCommand): | ||||
|         parser.add_argument("usernames", nargs="*", type=str) | ||||
|  | ||||
|     def handle_per_tenant(self, **options): | ||||
|         print(options) | ||||
|         new_type = UserTypes(options["type"]) | ||||
|         qs = ( | ||||
|             User.objects.exclude_anonymous() | ||||
|  | ||||
| @ -0,0 +1,44 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-13 15:12 | ||||
|  | ||||
| import uuid | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0048_delete_oldauthenticatedsession_content_type"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="File", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "id", | ||||
|                     models.UUIDField( | ||||
|                         default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("name", models.TextField()), | ||||
|                 ("content", models.BinaryField()), | ||||
|                 ("public", models.BooleanField(default=False)), | ||||
|             ], | ||||
|             options={ | ||||
|                 "verbose_name": "Files", | ||||
|             }, | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="application", | ||||
|             old_name="meta_icon", | ||||
|             new_name="meta_old_icon", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="application", | ||||
|             name="meta_icon", | ||||
|             field=models.ForeignKey( | ||||
|                 null=True, on_delete=django.db.models.deletion.SET_NULL, to="authentik_core.file" | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -0,0 +1,32 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-13 15:29 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_core", "0049_file_rename_meta_icon_application_meta_old_icon"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="file", | ||||
|             name="location", | ||||
|             field=models.TextField(null=True), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="file", | ||||
|             name="content", | ||||
|             field=models.BinaryField(null=True), | ||||
|         ), | ||||
|         migrations.AddConstraint( | ||||
|             model_name="file", | ||||
|             constraint=models.CheckConstraint( | ||||
|                 condition=models.Q( | ||||
|                     ("content__isnull", False), ("location__isnull", False), _connector="OR" | ||||
|                 ), | ||||
|                 name="one_of_content_location_is_defined", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -18,7 +18,7 @@ from django.http import HttpRequest | ||||
| from django.utils.functional import SimpleLazyObject, cached_property | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_cte import CTE, with_cte | ||||
| from django_cte import CTEQuerySet, With | ||||
| from guardian.conf import settings | ||||
| from guardian.mixins import GuardianUserMixin | ||||
| from model_utils.managers import InheritanceManager | ||||
| @ -29,6 +29,7 @@ from authentik.blueprints.models import ManagedModel | ||||
| from authentik.core.expression.exceptions import PropertyMappingExpressionException | ||||
| from authentik.core.types import UILoginButton, UserSettingSerializer | ||||
| from authentik.lib.avatars import get_avatar | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.expression.exceptions import ControlFlowException | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.merge import MERGE_LIST_UNIQUE | ||||
| @ -136,7 +137,7 @@ class AttributesMixin(models.Model): | ||||
|         return instance, False | ||||
|  | ||||
|  | ||||
| class GroupQuerySet(QuerySet): | ||||
| class GroupQuerySet(CTEQuerySet): | ||||
|     def with_children_recursive(self): | ||||
|         """Recursively get all groups that have the current queryset as parents | ||||
|         or are indirectly related.""" | ||||
| @ -165,9 +166,9 @@ class GroupQuerySet(QuerySet): | ||||
|             ) | ||||
|  | ||||
|         # Build the recursive query, see above | ||||
|         cte = CTE.recursive(make_cte) | ||||
|         cte = With.recursive(make_cte) | ||||
|         # Return the result, as a usable queryset for Group. | ||||
|         return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid)) | ||||
|         return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) | ||||
|  | ||||
|  | ||||
| class Group(SerializerModel, AttributesMixin): | ||||
| @ -533,12 +534,13 @@ class Application(SerializerModel, PolicyBindingModel): | ||||
|     ) | ||||
|  | ||||
|     # For template applications, this can be set to /static/authentik/applications/* | ||||
|     meta_icon = models.FileField( | ||||
|     meta_old_icon = models.FileField( | ||||
|         upload_to="application-icons/", | ||||
|         default=None, | ||||
|         null=True, | ||||
|         max_length=500, | ||||
|     ) | ||||
|     meta_icon = models.ForeignKey("File", null=True, on_delete=models.SET_NULL) | ||||
|     meta_description = models.TextField(default="", blank=True) | ||||
|     meta_publisher = models.TextField(default="", blank=True) | ||||
|  | ||||
| @ -1082,12 +1084,6 @@ class AuthenticatedSession(SerializerModel): | ||||
|  | ||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer | ||||
|  | ||||
|         return AuthenticatedSessionSerializer | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Authenticated Session") | ||||
|         verbose_name_plural = _("Authenticated Sessions") | ||||
| @ -1106,3 +1102,44 @@ class AuthenticatedSession(SerializerModel): | ||||
|             session=Session.objects.filter(session_key=request.session.session_key).first(), | ||||
|             user=user, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class File(SerializerModel): | ||||
|     id = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||
|  | ||||
|     name = models.TextField() | ||||
|     content = models.BinaryField(null=True) | ||||
|     location = models.TextField(null=True) | ||||
|     public = models.BooleanField(default=False) | ||||
|     delete_on_delete = models.BooleanField(default=False) | ||||
|     expiry = models.DateTimeField() | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("File") | ||||
|         verbose_name = _("Files") | ||||
|         constraints = ( | ||||
|             models.CheckConstraint( | ||||
|                 condition=Q(content__isnull=False) | Q(location__isnull=False), | ||||
|                 name="one_of_content_location_is_defined", | ||||
|             ), | ||||
|         ) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return self.name | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.core.api.files import FileSerializer | ||||
|  | ||||
|         return FileSerializer | ||||
|  | ||||
|     @property | ||||
|     def url(self) -> str: | ||||
|         if self.content: | ||||
|             return ( | ||||
|                 CONFIG.get("web.path", "/")[:-1] | ||||
|                 + f"/files/{'public' if self.public else 'private'}/{self.pk}" | ||||
|             ) | ||||
|         elif self.location.startswith("/static"): | ||||
|             return CONFIG.get("web.path", "/")[:-1] + self.location | ||||
|         return self.location | ||||
|  | ||||
| @ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -8,6 +8,7 @@ from authentik.core.api.application_entitlements import ApplicationEntitlementVi | ||||
| from authentik.core.api.applications import ApplicationViewSet | ||||
| from authentik.core.api.authenticated_sessions import AuthenticatedSessionViewSet | ||||
| from authentik.core.api.devices import AdminDeviceViewSet, DeviceViewSet | ||||
| from authentik.core.api.files import FileViewSet | ||||
| from authentik.core.api.groups import GroupViewSet | ||||
| from authentik.core.api.property_mappings import PropertyMappingViewSet | ||||
| from authentik.core.api.providers import ProviderViewSet | ||||
| @ -78,6 +79,7 @@ api_urlpatterns = [ | ||||
|         TransactionalApplicationView.as_view(), | ||||
|         name="core-transactional-application", | ||||
|     ), | ||||
|     ("core/files", FileViewSet), | ||||
|     ("core/groups", GroupViewSet), | ||||
|     ("core/users", UserViewSet), | ||||
|     ("core/tokens", TokenViewSet), | ||||
|  | ||||
| @ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase): | ||||
|             [ | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter1",  # nosec | ||||
|                     old_password="hunter1",  # nosec B106 | ||||
|                     created_at=_now - timedelta(days=3), | ||||
|                 ), | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter2",  # nosec | ||||
|                     old_password="hunter2",  # nosec B106 | ||||
|                     created_at=_now - timedelta(days=2), | ||||
|                 ), | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter3",  # nosec | ||||
|                     old_password="hunter3",  # nosec B106 | ||||
|                     created_at=_now, | ||||
|                 ), | ||||
|             ] | ||||
|  | ||||
| @ -1,8 +1,10 @@ | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import post_delete, post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http.request import HttpRequest | ||||
| from guardian.shortcuts import assign_perm | ||||
|  | ||||
| from authentik.core.models import ( | ||||
| @ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: | ||||
|             instance.save() | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_): | ||||
|     """Session revoked trigger (user logged out)""" | ||||
|     if not request.session or not request.session.session_key or not user: | ||||
|         return | ||||
|     send_ssf_event( | ||||
|         EventTypes.CAEP_SESSION_REVOKED, | ||||
|         { | ||||
|             "initiating_entity": "user", | ||||
|         }, | ||||
|         sub_id={ | ||||
|             "format": "complex", | ||||
|             "session": { | ||||
|                 "format": "opaque", | ||||
|                 "id": sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||
|             }, | ||||
|             "user": { | ||||
|                 "format": "email", | ||||
|                 "email": user.email, | ||||
|             }, | ||||
|         }, | ||||
|         request=request, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | ||||
|     """Session revoked trigger (users' session has been deleted) | ||||
|  | ||||
| @ -1,12 +0,0 @@ | ||||
| """Enterprise app config""" | ||||
|  | ||||
| from authentik.enterprise.apps import EnterpriseConfig | ||||
|  | ||||
|  | ||||
| class AuthentikEnterpriseSearchConfig(EnterpriseConfig): | ||||
|     """Enterprise app config""" | ||||
|  | ||||
|     name = "authentik.enterprise.search" | ||||
|     label = "authentik_search" | ||||
|     verbose_name = "authentik Enterprise.Search" | ||||
|     default = True | ||||
| @ -1,128 +0,0 @@ | ||||
| """DjangoQL search""" | ||||
|  | ||||
| from collections import OrderedDict, defaultdict | ||||
| from collections.abc import Generator | ||||
|  | ||||
| from django.db import connection | ||||
| from django.db.models import Model, Q | ||||
| from djangoql.compat import text_type | ||||
| from djangoql.schema import StrField | ||||
|  | ||||
|  | ||||
| class JSONSearchField(StrField): | ||||
|     """JSON field for DjangoQL""" | ||||
|  | ||||
|     model: Model | ||||
|  | ||||
|     def __init__(self, model=None, name=None, nullable=None, suggest_nested=True): | ||||
|         # Set this in the constructor to not clobber the type variable | ||||
|         self.type = "relation" | ||||
|         self.suggest_nested = suggest_nested | ||||
|         super().__init__(model, name, nullable) | ||||
|  | ||||
|     def get_lookup(self, path, operator, value): | ||||
|         search = "__".join(path) | ||||
|         op, invert = self.get_operator(operator) | ||||
|         q = Q(**{f"{search}{op}": self.get_lookup_value(value)}) | ||||
|         return ~q if invert else q | ||||
|  | ||||
|     def json_field_keys(self) -> Generator[tuple[str]]: | ||||
|         with connection.cursor() as cursor: | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 WITH RECURSIVE "{self.name}_keys" AS ( | ||||
|                     SELECT | ||||
|                         ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array, | ||||
|                         "{self.name}" -> jsonb_object_keys("{self.name}") AS value | ||||
|                     FROM {self.model._meta.db_table} | ||||
|                     WHERE "{self.name}" IS NOT NULL | ||||
|                         AND jsonb_typeof("{self.name}") = 'object' | ||||
|  | ||||
|                     UNION ALL | ||||
|  | ||||
|                     SELECT | ||||
|                         ck.key_path_array || jsonb_object_keys(ck.value), | ||||
|                         ck.value -> jsonb_object_keys(ck.value) AS value | ||||
|                     FROM "{self.name}_keys" ck | ||||
|                     WHERE jsonb_typeof(ck.value) = 'object' | ||||
|                 ), | ||||
|  | ||||
|                 unique_paths AS ( | ||||
|                     SELECT DISTINCT key_path_array | ||||
|                     FROM "{self.name}_keys" | ||||
|                 ) | ||||
|  | ||||
|                 SELECT key_path_array FROM unique_paths; | ||||
|             """  # nosec | ||||
|             ) | ||||
|             return (x[0] for x in cursor.fetchall()) | ||||
|  | ||||
|     def get_nested_options(self) -> OrderedDict: | ||||
|         """Get keys of all nested objects to show autocomplete""" | ||||
|         if not self.suggest_nested: | ||||
|             return OrderedDict() | ||||
|         base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" | ||||
|  | ||||
|         def recursive_function(parts: list[str], parent_parts: list[str] | None = None): | ||||
|             if not parent_parts: | ||||
|                 parent_parts = [] | ||||
|             path = parts.pop(0) | ||||
|             parent_parts.append(path) | ||||
|             relation_key = "_".join(parent_parts) | ||||
|             if len(parts) > 1: | ||||
|                 out_dict = { | ||||
|                     relation_key: { | ||||
|                         parts[0]: { | ||||
|                             "type": "relation", | ||||
|                             "relation": f"{relation_key}_{parts[0]}", | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 child_paths = recursive_function(parts.copy(), parent_parts.copy()) | ||||
|                 child_paths.update(out_dict) | ||||
|                 return child_paths | ||||
|             else: | ||||
|                 return {relation_key: {parts[0]: {}}} | ||||
|  | ||||
|         relation_structure = defaultdict(dict) | ||||
|  | ||||
|         for relations in self.json_field_keys(): | ||||
|             result = recursive_function([base_model_name] + relations) | ||||
|             for relation_key, value in result.items(): | ||||
|                 for sub_relation_key, sub_value in value.items(): | ||||
|                     if not relation_structure[relation_key].get(sub_relation_key, None): | ||||
|                         relation_structure[relation_key][sub_relation_key] = sub_value | ||||
|                     else: | ||||
|                         relation_structure[relation_key][sub_relation_key].update(sub_value) | ||||
|  | ||||
|         final_dict = defaultdict(dict) | ||||
|  | ||||
|         for key, value in relation_structure.items(): | ||||
|             for sub_key, sub_value in value.items(): | ||||
|                 if not sub_value: | ||||
|                     final_dict[key][sub_key] = { | ||||
|                         "type": "str", | ||||
|                         "nullable": True, | ||||
|                     } | ||||
|                 else: | ||||
|                     final_dict[key][sub_key] = sub_value | ||||
|         return OrderedDict(final_dict) | ||||
|  | ||||
|     def relation(self) -> str: | ||||
|         return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" | ||||
|  | ||||
|  | ||||
| class ChoiceSearchField(StrField): | ||||
|     def __init__(self, model=None, name=None, nullable=None): | ||||
|         super().__init__(model, name, nullable, suggest_options=True) | ||||
|  | ||||
|     def get_options(self, search): | ||||
|         result = [] | ||||
|         choices = self._field_choices() | ||||
|         if choices: | ||||
|             search = search.lower() | ||||
|             for c in choices: | ||||
|                 choice = text_type(c[0]) | ||||
|                 if search in choice.lower(): | ||||
|                     result.append(choice) | ||||
|         return result | ||||
| @ -1,53 +0,0 @@ | ||||
| from rest_framework.response import Response | ||||
|  | ||||
| from authentik.api.pagination import Pagination | ||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch | ||||
|  | ||||
|  | ||||
| class AutocompletePagination(Pagination): | ||||
|  | ||||
|     def paginate_queryset(self, queryset, request, view=None): | ||||
|         self.view = view | ||||
|         return super().paginate_queryset(queryset, request, view) | ||||
|  | ||||
|     def get_autocomplete(self): | ||||
|         schema = QLSearch().get_schema(self.request, self.view) | ||||
|         introspections = {} | ||||
|         if hasattr(self.view, "get_ql_fields"): | ||||
|             from authentik.enterprise.search.schema import AKQLSchemaSerializer | ||||
|  | ||||
|             introspections = AKQLSchemaSerializer().serialize( | ||||
|                 schema(self.page.paginator.object_list.model) | ||||
|             ) | ||||
|         return introspections | ||||
|  | ||||
|     def get_paginated_response(self, data): | ||||
|         previous_page_number = 0 | ||||
|         if self.page.has_previous(): | ||||
|             previous_page_number = self.page.previous_page_number() | ||||
|         next_page_number = 0 | ||||
|         if self.page.has_next(): | ||||
|             next_page_number = self.page.next_page_number() | ||||
|         return Response( | ||||
|             { | ||||
|                 "pagination": { | ||||
|                     "next": next_page_number, | ||||
|                     "previous": previous_page_number, | ||||
|                     "count": self.page.paginator.count, | ||||
|                     "current": self.page.number, | ||||
|                     "total_pages": self.page.paginator.num_pages, | ||||
|                     "start_index": self.page.start_index(), | ||||
|                     "end_index": self.page.end_index(), | ||||
|                 }, | ||||
|                 "results": data, | ||||
|                 "autocomplete": self.get_autocomplete(), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def get_paginated_response_schema(self, schema): | ||||
|         final_schema = super().get_paginated_response_schema(schema) | ||||
|         final_schema["properties"]["autocomplete"] = { | ||||
|             "$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}" | ||||
|         } | ||||
|         final_schema["required"].append("autocomplete") | ||||
|         return final_schema | ||||
| @ -1,80 +0,0 @@ | ||||
| """DjangoQL search""" | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.db.models import QuerySet | ||||
| from djangoql.ast import Name | ||||
| from djangoql.exceptions import DjangoQLError | ||||
| from djangoql.queryset import apply_search | ||||
| from djangoql.schema import DjangoQLSchema | ||||
| from rest_framework.filters import BaseFilterBackend, SearchFilter | ||||
| from rest_framework.request import Request | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.enterprise.search.fields import JSONSearchField | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete" | ||||
| AUTOCOMPLETE_SCHEMA = { | ||||
|     "type": "object", | ||||
|     "additionalProperties": {}, | ||||
| } | ||||
|  | ||||
|  | ||||
| class BaseSchema(DjangoQLSchema): | ||||
|     """Base Schema which deals with JSON Fields""" | ||||
|  | ||||
|     def resolve_name(self, name: Name): | ||||
|         model = self.model_label(self.current_model) | ||||
|         root_field = name.parts[0] | ||||
|         field = self.models[model].get(root_field) | ||||
|         # If the query goes into a JSON field, return the root | ||||
|         # field as the JSON field will do the rest | ||||
|         if isinstance(field, JSONSearchField): | ||||
|             # This is a workaround: build_filter strips the right-most entry in the | ||||
|             # path, assuming it matches the field name; for JSON fields that is not | ||||
|             # the case, so the root field is re-appended here | ||||
|             if name.parts[-1] != root_field: | ||||
|                 name.parts.append(root_field) | ||||
|             return field | ||||
|         return super().resolve_name(name) | ||||
|  | ||||
|  | ||||
| class QLSearch(BaseFilterBackend): | ||||
|     """rest_framework search filter which uses DjangoQL""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
|         self._fallback = SearchFilter() | ||||
|  | ||||
|     @property | ||||
|     def enabled(self): | ||||
|         return apps.get_app_config("authentik_enterprise").enabled() | ||||
|  | ||||
|     def get_search_terms(self, request: Request) -> str: | ||||
|         """Search terms are set by a ?search=... query parameter, | ||||
|         and may be comma and/or whitespace delimited.""" | ||||
|         params = request.query_params.get("search", "") | ||||
|         params = params.replace("\x00", "")  # strip null characters | ||||
|         return params | ||||
|  | ||||
|     def get_schema(self, request: Request, view) -> BaseSchema: | ||||
|         ql_fields = [] | ||||
|         if hasattr(view, "get_ql_fields"): | ||||
|             ql_fields = view.get_ql_fields() | ||||
|  | ||||
|         class InlineSchema(BaseSchema): | ||||
|             def get_fields(self, model): | ||||
|                 return ql_fields or [] | ||||
|  | ||||
|         return InlineSchema | ||||
|  | ||||
|     def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet: | ||||
|         search_query = self.get_search_terms(request) | ||||
|         schema = self.get_schema(request, view) | ||||
|         if len(search_query) == 0 or not self.enabled: | ||||
|             return self._fallback.filter_queryset(request, queryset, view) | ||||
|         try: | ||||
|             return apply_search(queryset, search_query, schema=schema) | ||||
|         except DjangoQLError as exc: | ||||
|             LOGGER.debug("Failed to parse search expression", exc=exc) | ||||
|             return self._fallback.filter_queryset(request, queryset, view) | ||||
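On the wire, the same ?search= parameter carries either a DjangoQL expression or a plain term; when parsing fails, or the enterprise app is disabled, the request falls through to the stock SearchFilter. A rough sketch of the two query styles (the /api/v3/core/users/ path is only assumed here for illustration):

    from urllib.parse import urlencode

    # DjangoQL expression, handled by QLSearch via apply_search()
    ql_query = urlencode({"search": 'username = "akadmin" and attributes.foo.bar = "baz"'})

    # Plain term, handled by the SearchFilter fallback
    plain_query = urlencode({"search": "akadmin"})

    print(f"/api/v3/core/users/?{ql_query}")
    print(f"/api/v3/core/users/?{plain_query}")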
| @ -1,29 +0,0 @@ | ||||
| from djangoql.serializers import DjangoQLSchemaSerializer | ||||
| from drf_spectacular.generators import SchemaGenerator | ||||
|  | ||||
| from authentik.api.schema import create_component | ||||
| from authentik.enterprise.search.fields import JSONSearchField | ||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA | ||||
|  | ||||
|  | ||||
| class AKQLSchemaSerializer(DjangoQLSchemaSerializer): | ||||
|     def serialize(self, schema): | ||||
|         serialization = super().serialize(schema) | ||||
|         for _, fields in schema.models.items(): | ||||
|             for _, field in fields.items(): | ||||
|                 if not isinstance(field, JSONSearchField): | ||||
|                     continue | ||||
|                 serialization["models"].update(field.get_nested_options()) | ||||
|         return serialization | ||||
|  | ||||
|     def serialize_field(self, field): | ||||
|         result = super().serialize_field(field) | ||||
|         if isinstance(field, JSONSearchField): | ||||
|             result["relation"] = field.relation() | ||||
|         return result | ||||
|  | ||||
|  | ||||
| def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs): | ||||
|     create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA) | ||||
|  | ||||
|     return result | ||||
| @ -1,17 +0,0 @@ | ||||
| SPECTACULAR_SETTINGS = { | ||||
|     "POSTPROCESSING_HOOKS": [ | ||||
|         "authentik.api.schema.postprocess_schema_responses", | ||||
|         "authentik.enterprise.search.schema.postprocess_schema_search_autocomplete", | ||||
|         "drf_spectacular.hooks.postprocess_schema_enums", | ||||
|     ], | ||||
| } | ||||
|  | ||||
| REST_FRAMEWORK = { | ||||
|     "DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination", | ||||
|     "DEFAULT_FILTER_BACKENDS": [ | ||||
|         "authentik.enterprise.search.ql.QLSearch", | ||||
|         "authentik.rbac.filters.ObjectFilter", | ||||
|         "django_filters.rest_framework.DjangoFilterBackend", | ||||
|         "rest_framework.filters.OrderingFilter", | ||||
|     ], | ||||
| } | ||||
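Put together, a list endpoint wired up through these settings returns an envelope in the shape below; the values are illustrative and mirror the pagination keys and the "autocomplete" key asserted in the tests elsewhere in this comparison:

    # Illustrative response envelope produced by AutocompletePagination; values are made up.
    example_response = {
        "pagination": {
            "next": 0,
            "previous": 0,
            "count": 1,
            "current": 1,
            "total_pages": 1,
            "start_index": 1,
            "end_index": 1,
        },
        "results": [{"username": "akadmin"}],
        # DjangoQL introspection data when the view defines get_ql_fields, otherwise empty
        "autocomplete": {},
    }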
| @ -1,78 +0,0 @@ | ||||
| from json import loads | ||||
| from unittest.mock import PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
|  | ||||
|  | ||||
| @patch( | ||||
|     "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|     PropertyMock(return_value=True), | ||||
| ) | ||||
| class QLTest(APITestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.user = create_test_admin_user() | ||||
|         # ensure we have more than 1 user | ||||
|         create_test_admin_user() | ||||
|  | ||||
|     def test_search(self): | ||||
|         """Test simple search query""" | ||||
|         self.client.force_login(self.user) | ||||
|         query = f'username = "{self.user.username}"' | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": query})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
|  | ||||
|     def test_no_search(self): | ||||
|         """Ensure works with no search query""" | ||||
|         self.client.force_login(self.user) | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertNotEqual(content["pagination"]["count"], 1) | ||||
|  | ||||
|     def test_search_no_ql(self): | ||||
|         """Test simple search query (no QL)""" | ||||
|         self.client.force_login(self.user) | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": self.user.username})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
|  | ||||
|     def test_search_json(self): | ||||
|         """Test search query with a JSON attribute""" | ||||
|         self.user.attributes = {"foo": {"bar": "baz"}} | ||||
|         self.user.save() | ||||
|         self.client.force_login(self.user) | ||||
|         query = 'attributes.foo.bar = "baz"' | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": query})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
| @ -18,7 +18,6 @@ TENANT_APPS = [ | ||||
|     "authentik.enterprise.providers.google_workspace", | ||||
|     "authentik.enterprise.providers.microsoft_entra", | ||||
|     "authentik.enterprise.providers.ssf", | ||||
|     "authentik.enterprise.search", | ||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||
|     "authentik.enterprise.stages.mtls", | ||||
|     "authentik.enterprise.stages.source", | ||||
|  | ||||
| @ -97,7 +97,6 @@ class SourceStageFinal(StageView): | ||||
|         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) | ||||
|         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) | ||||
|         plan = token.plan | ||||
|         plan.context.update(self.executor.plan.context) | ||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = token | ||||
|         response = plan.to_redirect(self.request, token.flow) | ||||
|         token.delete() | ||||
|  | ||||
| @ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase): | ||||
|         plan: FlowPlan = session[SESSION_KEY_PLAN] | ||||
|         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) | ||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token | ||||
|         plan.context["foo"] = "bar" | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         # Pretend we've just returned from the source | ||||
|         with self.assertFlowFinishes() as ff: | ||||
|             response = self.client.get( | ||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||
|             ) | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|             self.assertStageRedirects( | ||||
|                 response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||
|             ) | ||||
|         self.assertEqual(ff().context["foo"], "bar") | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageRedirects( | ||||
|             response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||
|         ) | ||||
|  | ||||
| @ -132,22 +132,6 @@ class EventViewSet(ModelViewSet): | ||||
|     ] | ||||
|     filterset_class = EventsFilter | ||||
|  | ||||
|     def get_ql_fields(self): | ||||
|         from djangoql.schema import DateTimeField, StrField | ||||
|  | ||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField | ||||
|  | ||||
|         return [ | ||||
|             ChoiceSearchField(Event, "action"), | ||||
|             StrField(Event, "event_uuid"), | ||||
|             StrField(Event, "app", suggest_options=True), | ||||
|             StrField(Event, "client_ip"), | ||||
|             JSONSearchField(Event, "user", suggest_nested=False), | ||||
|             JSONSearchField(Event, "brand", suggest_nested=False), | ||||
|             JSONSearchField(Event, "context", suggest_nested=False), | ||||
|             DateTimeField(Event, "created", suggest_options=True), | ||||
|         ] | ||||
|  | ||||
|     @extend_schema( | ||||
|         methods=["GET"], | ||||
|         responses={200: EventTopPerUserSerializer(many=True)}, | ||||
|  | ||||
| @ -11,7 +11,7 @@ from authentik.events.models import NotificationRule | ||||
| class NotificationRuleSerializer(ModelSerializer): | ||||
|     """NotificationRule Serializer""" | ||||
|  | ||||
|     destination_group_obj = GroupSerializer(read_only=True, source="destination_group") | ||||
|     group_obj = GroupSerializer(read_only=True, source="group") | ||||
|  | ||||
|     class Meta: | ||||
|         model = NotificationRule | ||||
| @ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer): | ||||
|             "name", | ||||
|             "transports", | ||||
|             "severity", | ||||
|             "destination_group", | ||||
|             "destination_group_obj", | ||||
|             "destination_event_user", | ||||
|             "group", | ||||
|             "group_obj", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| @ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet): | ||||
|  | ||||
|     queryset = NotificationRule.objects.all() | ||||
|     serializer_class = NotificationRuleSerializer | ||||
|     filterset_fields = ["name", "severity", "destination_group__name"] | ||||
|     filterset_fields = ["name", "severity", "group__name"] | ||||
|     ordering = ["name"] | ||||
|     search_fields = ["name", "destination_group__name"] | ||||
|     search_fields = ["name", "group__name"] | ||||
|  | ||||
| @ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor): | ||||
|         self.reader: Reader | None = None | ||||
|         self._last_mtime: float = 0.0 | ||||
|         self.logger = get_logger() | ||||
|         self.load() | ||||
|         self.open() | ||||
|  | ||||
|     def path(self) -> str | None: | ||||
|         """Get the path to the MMDB file to load""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def load(self): | ||||
|     def open(self): | ||||
|         """Get GeoIP Reader, if configured, otherwise none""" | ||||
|         path = self.path() | ||||
|         if path == "" or not path: | ||||
| @ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor): | ||||
|             diff = self._last_mtime < mtime | ||||
|             if diff > 0: | ||||
|                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) | ||||
|                 self.load() | ||||
|                 self.open() | ||||
|         except OSError as exc: | ||||
|             self.logger.warning("Failed to check MMDB age", exc=exc) | ||||
|  | ||||
|  | ||||
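A concrete processor only needs to point the base class at an MMDB file; __init__ calls open(), and the mtime comparison above reopens the reader when the database changes on disk. A minimal sketch; the module path and config key are assumptions made for illustration:

    from authentik.events.context_processors.mmdb import MMDBContextProcessor
    from authentik.lib.config import CONFIG


    class ExampleCityProcessor(MMDBContextProcessor):
        """Hypothetical processor backed by a GeoLite2-City database."""

        def path(self) -> str | None:
            # Config key assumed for illustration; returning "" or None leaves
            # self.reader unset and the processor inactive.
            return CONFIG.get("events.context_processors.geoip", None)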
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.events.models import Event, EventAction, Notification | ||||
| from authentik.events.utils import model_to_dict | ||||
| from authentik.lib.sentry import should_ignore_exception | ||||
| from authentik.lib.sentry import before_send | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.stages.authenticator_static.models import StaticToken | ||||
|  | ||||
| @ -173,7 +173,7 @@ class AuditMiddleware: | ||||
|                 message=exception_to_string(exception), | ||||
|             ) | ||||
|             thread.run() | ||||
|         elif not should_ignore_exception(exception): | ||||
|         elif before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||
|             thread = EventNewThread( | ||||
|                 EventAction.SYSTEM_EXCEPTION, | ||||
|                 request, | ||||
|  | ||||
| @ -1,26 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-16 23:21 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="notificationrule", | ||||
|             old_name="group", | ||||
|             new_name="destination_group", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="notificationrule", | ||||
|             name="destination_event_user", | ||||
|             field=models.BooleanField( | ||||
|                 default=False, | ||||
|                 help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,12 +1,10 @@ | ||||
| """authentik events models""" | ||||
|  | ||||
| from collections.abc import Generator | ||||
| from datetime import timedelta | ||||
| from difflib import get_close_matches | ||||
| from functools import lru_cache | ||||
| from inspect import currentframe | ||||
| from smtplib import SMTPException | ||||
| from typing import Any | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.apps import apps | ||||
| @ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel): | ||||
|             brand: Brand = request.brand | ||||
|             self.brand = sanitize_dict(model_to_dict(brand)) | ||||
|         if hasattr(request, "user"): | ||||
|             self.user = get_user(request.user) | ||||
|             original_user = None | ||||
|             if hasattr(request, "session"): | ||||
|                 original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None) | ||||
|             self.user = get_user(request.user, original_user) | ||||
|         if user: | ||||
|             self.user = get_user(user) | ||||
|         # Check if we're currently impersonating, and add that user | ||||
|         if hasattr(request, "session"): | ||||
|             from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
|  | ||||
|             # Check if we're currently impersonating, and add that user | ||||
|             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: | ||||
|                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) | ||||
|                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) | ||||
|             # Special case for events that happen during a flow: the user might not be | ||||
|             # authenticated yet but is a pending user instead | ||||
|             if SESSION_KEY_PLAN in request.session: | ||||
|                 from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | ||||
|  | ||||
|                 plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||
|                 pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None) | ||||
|                 # Only save `authenticated_as` if there's a different pending user in the flow | ||||
|                 # than the user that is authenticated | ||||
|                 if pending_user and ( | ||||
|                     (pending_user.pk and pending_user.pk != self.user.get("pk")) | ||||
|                     or (not pending_user.pk) | ||||
|                 ): | ||||
|                     orig_user = self.user.copy() | ||||
|  | ||||
|                     self.user = {"authenticated_as": orig_user, **get_user(pending_user)} | ||||
|         # Use 255.255.255.255 as fallback if IP cannot be determined | ||||
|         self.client_ip = ClientIPMiddleware.get_client_ip(request) | ||||
|         # Enrich event data | ||||
| @ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | ||||
|         default=NotificationSeverity.NOTICE, | ||||
|         help_text=_("Controls which severity level the created notifications will have."), | ||||
|     ) | ||||
|     destination_group = models.ForeignKey( | ||||
|     group = models.ForeignKey( | ||||
|         Group, | ||||
|         help_text=_( | ||||
|             "Define which group of users this notification should be sent and shown to. " | ||||
| @ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | ||||
|         blank=True, | ||||
|         on_delete=models.SET_NULL, | ||||
|     ) | ||||
|     destination_event_user = models.BooleanField( | ||||
|         default=False, | ||||
|         help_text=_( | ||||
|             "When enabled, notification will be sent to user the user that triggered the event." | ||||
|             "When destination_group is configured, notification is sent to both." | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     def destination_users(self, event: Event) -> Generator[User, Any]: | ||||
|         if self.destination_event_user and event.user.get("pk"): | ||||
|             yield User(pk=event.user.get("pk")) | ||||
|         if self.destination_group: | ||||
|             yield from self.destination_group.users.all() | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|  | ||||
| @ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||
|     if not result.passing: | ||||
|         return | ||||
|  | ||||
|     if not trigger.group: | ||||
|         LOGGER.debug("e(trigger): trigger has no group", trigger=trigger) | ||||
|         return | ||||
|  | ||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||
|     # Create the notification objects | ||||
|     for transport in trigger.transports.all(): | ||||
|         for user in trigger.destination_users(event): | ||||
|         for user in trigger.group.users.all(): | ||||
|             LOGGER.debug("created notification") | ||||
|             notification_transport.apply_async( | ||||
|                 args=[ | ||||
|  | ||||
| @ -2,9 +2,7 @@ | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.events.context_processors.base import get_context_processors | ||||
| from authentik.events.context_processors.geoip import GeoIPContextProcessor | ||||
| from authentik.events.models import Event, EventAction | ||||
|  | ||||
|  | ||||
| class TestGeoIP(TestCase): | ||||
| @ -15,7 +13,8 @@ class TestGeoIP(TestCase): | ||||
|  | ||||
|     def test_simple(self): | ||||
|         """Test simple city wrapper""" | ||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         # IPs from | ||||
|         # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         self.assertEqual( | ||||
|             self.reader.city_dict("2.125.160.216"), | ||||
|             { | ||||
| @ -26,12 +25,3 @@ class TestGeoIP(TestCase): | ||||
|                 "long": -1.25, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_special_chars(self): | ||||
|         """Test city name with special characters""" | ||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         event = Event.new(EventAction.LOGIN) | ||||
|         event.client_ip = "89.160.20.112" | ||||
|         for processor in get_context_processors(): | ||||
|             processor.enrich_event(event) | ||||
|         event.save() | ||||
|  | ||||
| @ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.core.models import Group | ||||
| from authentik.events.models import Event | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | ||||
| from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN | ||||
| from authentik.flows.views.executor import QS_QUERY | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
|  | ||||
| @ -118,92 +116,3 @@ class TestEvents(TestCase): | ||||
|                 "pk": brand.pk.hex, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = user | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user_anon(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = create_test_user() | ||||
|         anon = get_anonymous_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = anon | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "authenticated_as": { | ||||
|                     "pk": anon.pk, | ||||
|                     "is_anonymous": True, | ||||
|                     "username": "AnonymousUser", | ||||
|                     "email": "", | ||||
|                 }, | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user_fake(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = User( | ||||
|             username=generate_id(), | ||||
|             email=generate_id(), | ||||
|         ) | ||||
|         anon = get_anonymous_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = anon | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "authenticated_as": { | ||||
|                     "pk": anon.pk, | ||||
|                     "is_anonymous": True, | ||||
|                     "username": "AnonymousUser", | ||||
|                     "email": "", | ||||
|                 }, | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -6,7 +6,6 @@ from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.events.models import ( | ||||
|     Event, | ||||
|     EventAction, | ||||
| @ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase): | ||||
|     def test_trigger_empty(self): | ||||
|         """Test trigger without any policies attached""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|  | ||||
| @ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase): | ||||
|     def test_trigger_single(self): | ||||
|         """Test simple transport triggering""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase): | ||||
|             Event.new(EventAction.CUSTOM_PREFIX).save() | ||||
|         self.assertEqual(execute_mock.call_count, 1) | ||||
|  | ||||
|     def test_trigger_event_user(self): | ||||
|         """Test trigger with event user""" | ||||
|         user = create_test_user() | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|         ) | ||||
|         PolicyBinding.objects.create(target=trigger, policy=matcher, order=0) | ||||
|  | ||||
|         execute_mock = MagicMock() | ||||
|         with patch("authentik.events.models.NotificationTransport.send", execute_mock): | ||||
|             Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save() | ||||
|         self.assertEqual(execute_mock.call_count, 1) | ||||
|         notification: Notification = execute_mock.call_args[0][0] | ||||
|         self.assertEqual(notification.user, user) | ||||
|  | ||||
|     def test_trigger_no_group(self): | ||||
|         """Test trigger without group""" | ||||
|         trigger = NotificationRule.objects.create(name=generate_id()) | ||||
| @ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase): | ||||
|         """Test Policy error which would cause recursion""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase): | ||||
|  | ||||
|         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase): | ||||
|             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL | ||||
|         ) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|  | ||||
| @ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]: | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_user(user: User | AnonymousUser) -> dict[str, Any]: | ||||
|     """Convert user object to dictionary""" | ||||
| def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||
|     """Convert user object to dictionary, optionally including the original user""" | ||||
|     if isinstance(user, AnonymousUser): | ||||
|         try: | ||||
|             user = get_anonymous_user() | ||||
| @ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]: | ||||
|     } | ||||
|     if user.username == settings.ANONYMOUS_USER_NAME: | ||||
|         user_data["is_anonymous"] = True | ||||
|     if original_user: | ||||
|         original_data = get_user(original_user) | ||||
|         original_data["on_behalf_of"] = user_data | ||||
|         return original_data | ||||
|     return user_data | ||||
|  | ||||
|  | ||||
|  | ||||
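With the reworked signature, impersonation is expressed by passing the original user explicitly instead of inspecting the session inside Event.from_http. A sketch of the resulting shapes, using hypothetical unsaved users and meant for a Django shell inside authentik:

    from authentik.core.models import User
    from authentik.events.utils import get_user

    impersonated = User(username="jane", email="jane@example.org")
    operator = User(username="admin", email="admin@example.org")

    # Plain case: a flat dict of pk/username/email
    print(get_user(impersonated))

    # Impersonation case: the original (impersonating) user's dict is returned,
    # with the request user nested under "on_behalf_of", matching what
    # Event.from_http now stores
    print(get_user(impersonated, original_user=operator))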
| @ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.test import override_settings | ||||
| from django.test.client import RequestFactory | ||||
| from django.urls import reverse | ||||
| from rest_framework.exceptions import ParseError | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_flow, create_test_user | ||||
| @ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase): | ||||
|             self.assertStageResponse(response, flow, component="ak-stage-identification") | ||||
|             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) | ||||
|             self.assertStageResponse(response, flow, component="ak-stage-access-denied") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_json(self): | ||||
|         """Test invalid JSON body""" | ||||
|         flow = create_test_flow() | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 | ||||
|         ) | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         with override_settings(TEST=False, DEBUG=False): | ||||
|             self.client.logout() | ||||
|             response = self.client.post(url, data="{", content_type="application/json") | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|         with self.assertRaises(ParseError): | ||||
|             self.client.logout() | ||||
|             response = self.client.post(url, data="{", content_type="application/json") | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|  | ||||
| @ -55,7 +55,7 @@ from authentik.flows.planner import ( | ||||
|     FlowPlanner, | ||||
| ) | ||||
| from authentik.flows.stage import AccessDeniedStage, StageView | ||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.lib.utils.reflection import all_subclasses, class_to_path | ||||
| from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | ||||
| @ -234,13 +234,12 @@ class FlowExecutorView(APIView): | ||||
|         """Handle exception in stage execution""" | ||||
|         if settings.DEBUG or settings.TEST: | ||||
|             raise exc | ||||
|         capture_exception(exc) | ||||
|         self._logger.warning(exc) | ||||
|         if not should_ignore_exception(exc): | ||||
|             capture_exception(exc) | ||||
|             Event.new( | ||||
|                 action=EventAction.SYSTEM_EXCEPTION, | ||||
|                 message=exception_to_string(exc), | ||||
|             ).from_http(self.request) | ||||
|         Event.new( | ||||
|             action=EventAction.SYSTEM_EXCEPTION, | ||||
|             message=exception_to_string(exc), | ||||
|         ).from_http(self.request) | ||||
|         challenge = FlowErrorChallenge(self.request, exc) | ||||
|         challenge.is_valid(raise_exception=True) | ||||
|         return to_stage_response(self.request, HttpChallengeResponse(challenge)) | ||||
|  | ||||
| @ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted | ||||
| from docker.errors import DockerException | ||||
| from h11 import LocalProtocolError | ||||
| from ldap3.core.exceptions import LDAPException | ||||
| from psycopg.errors import Error | ||||
| from redis.exceptions import ConnectionError as RedisConnectionError | ||||
| from redis.exceptions import RedisError, ResponseError | ||||
| from rest_framework.exceptions import APIException | ||||
| @ -45,49 +44,6 @@ class SentryIgnoredException(Exception): | ||||
|     """Base Class for all errors that are suppressed, and not sent to sentry.""" | ||||
|  | ||||
|  | ||||
| ignored_classes = ( | ||||
|     # Inbuilt types | ||||
|     KeyboardInterrupt, | ||||
|     ConnectionResetError, | ||||
|     OSError, | ||||
|     PermissionError, | ||||
|     # Django Errors | ||||
|     Error, | ||||
|     ImproperlyConfigured, | ||||
|     DatabaseError, | ||||
|     OperationalError, | ||||
|     InternalError, | ||||
|     ProgrammingError, | ||||
|     SuspiciousOperation, | ||||
|     ValidationError, | ||||
|     # Redis errors | ||||
|     RedisConnectionError, | ||||
|     ConnectionInterrupted, | ||||
|     RedisError, | ||||
|     ResponseError, | ||||
|     # websocket errors | ||||
|     ChannelFull, | ||||
|     WebSocketException, | ||||
|     LocalProtocolError, | ||||
|     # rest_framework error | ||||
|     APIException, | ||||
|     # celery errors | ||||
|     WorkerLostError, | ||||
|     CeleryError, | ||||
|     SoftTimeLimitExceeded, | ||||
|     # custom baseclass | ||||
|     SentryIgnoredException, | ||||
|     # ldap errors | ||||
|     LDAPException, | ||||
|     # Docker errors | ||||
|     DockerException, | ||||
|     # End-user errors | ||||
|     Http404, | ||||
|     # AsyncIO | ||||
|     CancelledError, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class SentryTransport(HttpTransport): | ||||
|     """Custom sentry transport with custom user-agent""" | ||||
|  | ||||
| @ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float: | ||||
|     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) | ||||
|  | ||||
|  | ||||
| def should_ignore_exception(exc: Exception) -> bool: | ||||
|     """Check if an exception should be dropped""" | ||||
|     return isinstance(exc, ignored_classes) | ||||
|  | ||||
|  | ||||
| def before_send(event: dict, hint: dict) -> dict | None: | ||||
|     """Check if error is database error, and ignore if so""" | ||||
|  | ||||
|     from psycopg.errors import Error | ||||
|  | ||||
|     ignored_classes = ( | ||||
|         # Inbuilt types | ||||
|         KeyboardInterrupt, | ||||
|         ConnectionResetError, | ||||
|         OSError, | ||||
|         PermissionError, | ||||
|         # Django Errors | ||||
|         Error, | ||||
|         ImproperlyConfigured, | ||||
|         DatabaseError, | ||||
|         OperationalError, | ||||
|         InternalError, | ||||
|         ProgrammingError, | ||||
|         SuspiciousOperation, | ||||
|         ValidationError, | ||||
|         # Redis errors | ||||
|         RedisConnectionError, | ||||
|         ConnectionInterrupted, | ||||
|         RedisError, | ||||
|         ResponseError, | ||||
|         # websocket errors | ||||
|         ChannelFull, | ||||
|         WebSocketException, | ||||
|         LocalProtocolError, | ||||
|         # rest_framework error | ||||
|         APIException, | ||||
|         # celery errors | ||||
|         WorkerLostError, | ||||
|         CeleryError, | ||||
|         SoftTimeLimitExceeded, | ||||
|         # custom baseclass | ||||
|         SentryIgnoredException, | ||||
|         # ldap errors | ||||
|         LDAPException, | ||||
|         # Docker errors | ||||
|         DockerException, | ||||
|         # End-user errors | ||||
|         Http404, | ||||
|         # AsyncIO | ||||
|         CancelledError, | ||||
|     ) | ||||
|     exc_value = None | ||||
|     if "exc_info" in hint: | ||||
|         _, exc_value, _ = hint["exc_info"] | ||||
|         if should_ignore_exception(exc_value): | ||||
|         if isinstance(exc_value, ignored_classes): | ||||
|             LOGGER.debug("dropping exception", exc=exc_value) | ||||
|             return None | ||||
|     if "logger" in event: | ||||
|  | ||||
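Call sites that previously used should_ignore_exception (the audit middleware and the celery error hook elsewhere in this comparison) now probe before_send directly. A small helper capturing that pattern, named hypothetically:

    from authentik.lib.sentry import before_send


    def is_reportable(exc: Exception) -> bool:
        """Return True when before_send would keep the event (hypothetical helper)."""
        return before_send({}, {"exc_info": (None, exc, None)}) is not None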
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | ||||
| from authentik.lib.sentry import SentryIgnoredException, before_send | ||||
|  | ||||
|  | ||||
| class TestSentry(TestCase): | ||||
| @ -10,8 +10,8 @@ class TestSentry(TestCase): | ||||
|  | ||||
|     def test_error_not_sent(self): | ||||
|         """Test SentryIgnoredError not sent""" | ||||
|         self.assertTrue(should_ignore_exception(SentryIgnoredException())) | ||||
|         self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) | ||||
|  | ||||
|     def test_error_sent(self): | ||||
|         """Test error sent""" | ||||
|         self.assertFalse(should_ignore_exception(ValueError())) | ||||
|         self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)})) | ||||
|  | ||||
| @ -1,13 +1,15 @@ | ||||
| """authentik outpost signals""" | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import AuthenticatedSession, Provider | ||||
| from authentik.core.models import AuthenticatedSession, Provider, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||
| @ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): | ||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||
|     """Catch logout by direct logout and forward to providers""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     outpost_session_end.delay(request.session.session_key) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|  | ||||
| @ -1,10 +1,23 @@ | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_): | ||||
|     """Revoke tokens upon user logout""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     AccessToken.objects.filter( | ||||
|         user=user, | ||||
|         session__session__session_key=request.session.session_key, | ||||
|     ).delete() | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | ||||
|     """Revoke tokens upon user logout""" | ||||
|  | ||||
| @ -387,7 +387,8 @@ class TestAuthorize(OAuthTestCase): | ||||
|             self.assertEqual( | ||||
|                 response.url, | ||||
|                 ( | ||||
|                     f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}" | ||||
|                     f"http://localhost#access_token={token.token}" | ||||
|                     f"&id_token={provider.encode(token.id_token.to_dict())}" | ||||
|                     f"&token_type={TOKEN_TYPE}" | ||||
|                     f"&expires_in={int(expires)}&state={state}" | ||||
|                 ), | ||||
| @ -562,6 +563,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 "url": "http://localhost", | ||||
|                 "title": f"Redirecting to {app.name}...", | ||||
|                 "attrs": { | ||||
|                     "access_token": token.token, | ||||
|                     "id_token": provider.encode(token.id_token.to_dict()), | ||||
|                     "token_type": TOKEN_TYPE, | ||||
|                     "expires_in": "3600", | ||||
|  | ||||
| @ -150,12 +150,12 @@ class OAuthAuthorizationParams: | ||||
|         self.check_redirect_uri() | ||||
|         self.check_grant() | ||||
|         self.check_scope(github_compat) | ||||
|         self.check_nonce() | ||||
|         self.check_code_challenge() | ||||
|         if self.request: | ||||
|             raise AuthorizeError( | ||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||
|             ) | ||||
|         self.check_nonce() | ||||
|         self.check_code_challenge() | ||||
|  | ||||
|     def check_grant(self): | ||||
|         """Check grant""" | ||||
| @ -630,6 +630,7 @@ class OAuthFulfillmentStage(StageView): | ||||
|         if self.params.response_type in [ | ||||
|             ResponseTypes.ID_TOKEN_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||
|             ResponseTypes.ID_TOKEN, | ||||
|             ResponseTypes.CODE_TOKEN, | ||||
|         ]: | ||||
|             query_fragment["access_token"] = token.token | ||||
|  | ||||
| @ -66,10 +66,7 @@ class RACClientConsumer(AsyncWebsocketConsumer): | ||||
|     def init_outpost_connection(self): | ||||
|         """Initialize guac connection settings""" | ||||
|         self.token = ( | ||||
|             ConnectionToken.filter_not_expired( | ||||
|                 token=self.scope["url_route"]["kwargs"]["token"], | ||||
|                 session__session__session_key=self.scope["session"].session_key, | ||||
|             ) | ||||
|             ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"]) | ||||
|             .select_related("endpoint", "provider", "session", "session__user") | ||||
|             .first() | ||||
|         ) | ||||
|  | ||||
| @ -2,11 +2,13 @@ | ||||
|  | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.core.cache import cache | ||||
| from django.db.models.signals import post_delete, post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | ||||
| from authentik.providers.rac.consumer_client import ( | ||||
|     RAC_CLIENT_GROUP_SESSION, | ||||
| @ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import ( | ||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def user_logged_out_session(sender, request: HttpRequest, user: User, **_): | ||||
|     """Disconnect any open RAC connections""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     layer = get_channel_layer() | ||||
|     async_to_sync(layer.group_send)( | ||||
|         RAC_CLIENT_GROUP_SESSION | ||||
|         % { | ||||
|             "session": request.session.session_key, | ||||
|         }, | ||||
|         {"type": "event.disconnect", "reason": "session_logout"}, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def user_session_deleted(sender, instance: AuthenticatedSession, **_): | ||||
|     layer = get_channel_layer() | ||||
|  | ||||
| @ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -87,22 +87,3 @@ class TestRACViews(APITestCase): | ||||
|         ) | ||||
|         body = loads(flow_response.content) | ||||
|         self.assertEqual(body["component"], "ak-stage-access-denied") | ||||
|  | ||||
|     def test_different_session(self): | ||||
|         """Test request""" | ||||
|         self.client.force_login(self.user) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_providers_rac:start", | ||||
|                 kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         flow_response = self.client.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) | ||||
|         ) | ||||
|         body = loads(flow_response.content) | ||||
|         next_url = body["to"] | ||||
|         self.client.logout() | ||||
|         final_response = self.client.get(next_url) | ||||
|         self.assertEqual(final_response.url, reverse("authentik_core:if-user")) | ||||
|  | ||||
| @ -20,9 +20,6 @@ from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.policies.engine import PolicyEngine | ||||
| from authentik.policies.views import PolicyAccessView | ||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider | ||||
| from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT | ||||
|  | ||||
| PLAN_CONNECTION_SETTINGS = "connection_settings" | ||||
|  | ||||
|  | ||||
| class RACStartView(PolicyAccessView): | ||||
| @ -68,10 +65,7 @@ class RACInterface(InterfaceView): | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
|         # Early sanity check to ensure token still exists | ||||
|         token = ConnectionToken.filter_not_expired( | ||||
|             token=self.kwargs["token"], | ||||
|             session__session__session_key=request.session.session_key, | ||||
|         ).first() | ||||
|         token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first() | ||||
|         if not token: | ||||
|             return redirect("authentik_core:if-user") | ||||
|         self.token = token | ||||
| @ -115,15 +109,10 @@ class RACFinalStage(RedirectStage): | ||||
|         return super().dispatch(request, *args, **kwargs) | ||||
|  | ||||
|     def get_challenge(self, *args, **kwargs) -> RedirectChallenge: | ||||
|         settings = self.executor.plan.context.get(PLAN_CONNECTION_SETTINGS) | ||||
|         if not settings: | ||||
|             settings = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).get( | ||||
|                 PLAN_CONNECTION_SETTINGS | ||||
|             ) | ||||
|         token = ConnectionToken.objects.create( | ||||
|             provider=self.provider, | ||||
|             endpoint=self.endpoint, | ||||
|             settings=settings or {}, | ||||
|             settings=self.executor.plan.context.get("connection_settings", {}), | ||||
|             session=self.request.session["authenticatedsession"], | ||||
|             expires=now() + timedelta_from_string(self.provider.connection_expiry), | ||||
|             expiring=True, | ||||
|  | ||||
| @ -5,6 +5,7 @@ from itertools import batched | ||||
| from django.db import transaction | ||||
| from pydantic import ValidationError | ||||
| from pydanticscim.group import GroupMember | ||||
| from pydanticscim.responses import PatchOp | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.lib.sync.mapper import PropertyMappingManager | ||||
| @ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | ||||
| from authentik.providers.scim.clients.exceptions import ( | ||||
|     SCIMRequestException, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import ( | ||||
|     SCIM_GROUP_SCHEMA, | ||||
|     PatchOp, | ||||
|     PatchOperation, | ||||
|     PatchRequest, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.providers.scim.models import ( | ||||
|     SCIMMapping, | ||||
|  | ||||
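With the custom enum dropped, PatchOp is imported from pydanticscim again, as the import added above shows. A rough sketch of assembling a PATCH body with the classes kept in schema.py, assuming pydanticscim's stock PatchOp members and pydantic v2 serialization; the member id is a placeholder:

    from pydanticscim.responses import PatchOp

    from authentik.providers.scim.clients.schema import PatchOperation, PatchRequest

    patch = PatchRequest(
        Operations=[
            PatchOperation(
                op=PatchOp.add,
                path="members",
                value=[{"value": "placeholder-user-id"}],
            )
        ]
    )
    print(patch.model_dump_json(exclude_unset=True))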
| @ -1,7 +1,5 @@ | ||||
| """Custom SCIM schemas""" | ||||
|  | ||||
| from enum import Enum | ||||
|  | ||||
| from pydantic import Field | ||||
| from pydanticscim.group import Group as BaseGroup | ||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||
| @ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class PatchOp(str, Enum): | ||||
|  | ||||
|     replace = "replace" | ||||
|     remove = "remove" | ||||
|     add = "add" | ||||
|  | ||||
|     @classmethod | ||||
|     def _missing_(cls, value): | ||||
|         value = value.lower() | ||||
|         for member in cls: | ||||
|             if member.lower() == value: | ||||
|                 return member | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class PatchRequest(BasePatchRequest): | ||||
|     """PatchRequest which correctly sets schemas""" | ||||
|  | ||||
| @ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest): | ||||
| class PatchOperation(BasePatchOperation): | ||||
|     """PatchOperation with optional path""" | ||||
|  | ||||
|     op: PatchOp | ||||
|     path: str | None | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -38,7 +38,6 @@ class TestAPIPerms(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -74,7 +73,6 @@ class TestAPIPerms(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | ||||
|  | ||||
| import django | ||||
| from channels.routing import ProtocolTypeRouter, URLRouter | ||||
| from defusedxml import defuse_stdlib | ||||
| from django.core.asgi import get_asgi_application | ||||
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | ||||
|  | ||||
| from authentik.root.setup import setup | ||||
|  | ||||
| # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | ||||
|  | ||||
| setup() | ||||
| defuse_stdlib() | ||||
| django.setup() | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -27,7 +27,7 @@ from structlog.stdlib import get_logger | ||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.lib.sentry import should_ignore_exception | ||||
| from authentik.lib.sentry import before_send | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
|  | ||||
| # set the default Django settings module for the 'celery' program. | ||||
| @ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | ||||
|  | ||||
|     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) | ||||
|     CTX_TASK_ID.set(...) | ||||
|     if not should_ignore_exception(exception): | ||||
|     if before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||
|         Event.new( | ||||
|             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id | ||||
|         ).save() | ||||
|  | ||||
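Both sides of this hunk gate Event creation on whether the exception would be reported to Sentry; one side builds a minimal hint dict for authentik's before_send hook, the other calls should_ignore_exception directly. A hedged sketch of that gate, assuming before_send follows the usual Sentry contract of returning None for dropped events (the helper name is illustrative):

def should_record_event(exception: Exception, before_send) -> bool:
    # before_send(event, hint) returns None when the exception is filtered out,
    # so a non-None return value means the failure is worth recording as an Event.
    hint = {"exc_info": (None, exception, None)}
    return before_send({}, hint) is not None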
| @ -1,49 +1,13 @@ | ||||
| """authentik database backend""" | ||||
|  | ||||
| from django.core.checks import Warning | ||||
| from django.db.backends.base.validation import BaseDatabaseValidation | ||||
| from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
|  | ||||
| class DatabaseValidation(BaseDatabaseValidation): | ||||
|  | ||||
|     def check(self, **kwargs): | ||||
|         return self._check_encoding() | ||||
|  | ||||
|     def _check_encoding(self): | ||||
|         """Throw a warning when the server_encoding is not UTF-8 or | ||||
|         server_encoding and client_encoding are mismatched""" | ||||
|         messages = [] | ||||
|         with self.connection.cursor() as cursor: | ||||
|             cursor.execute("SHOW server_encoding;") | ||||
|             server_encoding = cursor.fetchone()[0] | ||||
|             cursor.execute("SHOW client_encoding;") | ||||
|             client_encoding = cursor.fetchone()[0] | ||||
|             if server_encoding != client_encoding: | ||||
|                 messages.append( | ||||
|                     Warning( | ||||
|                         "PostgreSQL Server and Client encoding are mismatched: Server: " | ||||
|                         f"{server_encoding}, Client: {client_encoding}", | ||||
|                         id="ak.db.W001", | ||||
|                     ) | ||||
|                 ) | ||||
|             if server_encoding != "UTF8": | ||||
|                 messages.append( | ||||
|                     Warning( | ||||
|                         f"PostgreSQL Server encoding is not UTF8: {server_encoding}", | ||||
|                         id="ak.db.W002", | ||||
|                     ) | ||||
|                 ) | ||||
|         return messages | ||||
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     """database backend which supports rotating credentials""" | ||||
|  | ||||
|     validation_class = DatabaseValidation | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         """Refresh DB credentials before getting connection params""" | ||||
|         conn_params = super().get_connection_params() | ||||
|  | ||||
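The DatabaseValidation class removed in this hunk warns on mismatched or non-UTF8 PostgreSQL encodings. The same probe, sketched outside Django's check framework against any DB-API compatible connection (function name is illustrative):

def check_postgres_encodings(connection) -> list[str]:
    """Return warning strings mirroring ak.db.W001/W002 from the hunk above."""
    warnings = []
    with connection.cursor() as cursor:
        cursor.execute("SHOW server_encoding;")
        server_encoding = cursor.fetchone()[0]
        cursor.execute("SHOW client_encoding;")
        client_encoding = cursor.fetchone()[0]
    if server_encoding != client_encoding:
        warnings.append(
            f"PostgreSQL server/client encoding mismatch: {server_encoding} vs {client_encoding}"
        )
    if server_encoding != "UTF8":
        warnings.append(f"PostgreSQL server encoding is not UTF8: {server_encoding}")
    return warnings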
| @ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [ | ||||
|     "MIDDLEWARE", | ||||
|     "AUTHENTICATION_BACKENDS", | ||||
|     "CELERY", | ||||
|     "SPECTACULAR_SETTINGS", | ||||
|     "REST_FRAMEWORK", | ||||
| ] | ||||
|  | ||||
| SILENCED_SYSTEM_CHECKS = [ | ||||
| @ -470,8 +468,6 @@ def _update_settings(app_path: str): | ||||
|         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) | ||||
|         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) | ||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) | ||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) | ||||
|         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) | ||||
|         for _attr in dir(settings_module): | ||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||
|  | ||||
| @ -1,26 +0,0 @@ | ||||
| import os | ||||
| import warnings | ||||
|  | ||||
| from cryptography.hazmat.backends.openssl.backend import backend | ||||
| from defusedxml import defuse_stdlib | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
|  | ||||
| def setup(): | ||||
|     warnings.filterwarnings("ignore", "SelectableGroups dict interface") | ||||
|     warnings.filterwarnings( | ||||
|         "ignore", | ||||
|         "defusedxml.lxml is no longer supported and will be removed in a future release.", | ||||
|     ) | ||||
|     warnings.filterwarnings( | ||||
|         "ignore", | ||||
|         "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.", | ||||
|     ) | ||||
|  | ||||
|     defuse_stdlib() | ||||
|  | ||||
|     if CONFIG.get_bool("compliance.fips.enabled", False): | ||||
|         backend._enable_fips() | ||||
|  | ||||
|     os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") | ||||
| @ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType | ||||
| from django.test.runner import DiscoverRunner | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR | ||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.sentry import sentry_init | ||||
| from authentik.root.signals import post_startup, pre_startup, startup | ||||
| @ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover | ||||
|         for key, value in test_config.items(): | ||||
|             CONFIG.set(key, value) | ||||
|  | ||||
|         ASN_CONTEXT_PROCESSOR.load() | ||||
|         GEOIP_CONTEXT_PROCESSOR.load() | ||||
|  | ||||
|         sentry_init() | ||||
|         self.logger.debug("Test environment configured") | ||||
|  | ||||
|  | ||||
| @ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str): | ||||
|             return | ||||
|         # Delete all sync tasks from the cache | ||||
|         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() | ||||
|  | ||||
|         # The order of these operations needs to be preserved as each depends on the previous one(s) | ||||
|         # 1. User and group sync can happen simultaneously | ||||
|         # 2. Membership sync needs to run afterwards | ||||
|         # 3. Finally, user and group deletions can happen simultaneously | ||||
|         user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( | ||||
|             source, GroupLDAPSynchronizer | ||||
|         task = chain( | ||||
|             # User and group sync can happen at once, they have no dependencies on each other | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, UserLDAPSynchronizer) | ||||
|                 + ldap_sync_paginator(source, GroupLDAPSynchronizer), | ||||
|             ), | ||||
|             # Membership sync needs to run afterwards | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, MembershipLDAPSynchronizer), | ||||
|             ), | ||||
|             # Finally, deletions. What we'd really like to do here is something like | ||||
|             # ``` | ||||
|             # user_identifiers = <ldap query> | ||||
|             # User.objects.exclude( | ||||
|             #     usersourceconnection__identifier__in=user_uniqueness_identifiers, | ||||
|             # ).delete() | ||||
|             # ``` | ||||
|             # This runs into performance issues in large installations. So instead we spread the | ||||
|             # work out into three steps: | ||||
|             # 1. Get every object from the LDAP source. | ||||
|             # 2. Mark every object as "safe" in the database. This is quick, but any error could | ||||
|             #    mean deleting users which should not be deleted, so we do it immediately, in | ||||
|             #    large chunks, and only queue the deletion step afterwards. | ||||
|             # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in | ||||
|             #    small chunks. | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, UserLDAPForwardDeletion) | ||||
|                 + ldap_sync_paginator(source, GroupLDAPForwardDeletion), | ||||
|             ), | ||||
|         ) | ||||
|         membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) | ||||
|         user_group_deletion = ldap_sync_paginator( | ||||
|             source, UserLDAPForwardDeletion | ||||
|         ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion) | ||||
|  | ||||
|         # Celery is buggy with empty groups, so we are careful only to add non-empty groups. | ||||
|         # See https://github.com/celery/celery/issues/9772 | ||||
|         task_groups = [] | ||||
|         if user_group_sync: | ||||
|             task_groups.append(group(user_group_sync)) | ||||
|         if membership_sync: | ||||
|             task_groups.append(group(membership_sync)) | ||||
|         if user_group_deletion: | ||||
|             task_groups.append(group(user_group_deletion)) | ||||
|  | ||||
|         all_tasks = chain(task_groups) | ||||
|         all_tasks() | ||||
|         task() | ||||
|  | ||||
|  | ||||
| def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | ||||
|  | ||||
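Both variants of ldap_sync_single in this hunk build the same three ordered phases (user and group sync, then membership, then forward deletions); one of them additionally skips empty Celery groups because of celery/celery#9772. A minimal sketch of that guard, with the task lists standing in for ldap_sync_paginator results:

from celery import chain, group


def build_sync_chain(user_group_sync: list, membership_sync: list, deletions: list):
    """Chain the phases in order, dropping any phase that has no tasks."""
    phases = [
        group(tasks)
        for tasks in (user_group_sync, membership_sync, deletions)
        if tasks  # chaining an empty group is what triggers celery/celery#9772
    ]
    return chain(phases) if phases else None


# Usage sketch:
# sync_chain = build_sync_chain(user_sync + group_sync, membership_sync, forward_deletions)
# if sync_chain:
#     sync_chain()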
| @ -1,277 +0,0 @@ | ||||
| """Test SCIM Group""" | ||||
|  | ||||
| from json import dumps | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.sources.scim.models import ( | ||||
|     SCIMSource, | ||||
|     SCIMSourceGroup, | ||||
| ) | ||||
| from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE | ||||
|  | ||||
|  | ||||
| class TestSCIMGroups(APITestCase): | ||||
|     """Test SCIM Group view""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id()) | ||||
|  | ||||
|     def test_group_list(self): | ||||
|         """Test full group list""" | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_group_list_single(self): | ||||
|         """Test full group list (single group)""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         user = create_test_user() | ||||
|         group.users.add(user) | ||||
|         SCIMSourceGroup.objects.create( | ||||
|             source=self.source, | ||||
|             group=group, | ||||
|             id=str(uuid4()), | ||||
|         ) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "group_id": str(group.pk), | ||||
|                 }, | ||||
|             ), | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         SCIMGroupSchema.model_validate_json(response.content, strict=True) | ||||
|  | ||||
|     def test_group_create(self): | ||||
|         """Test group create""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_members(self): | ||||
|         """Test group create""" | ||||
|         user = create_test_user() | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "displayName": generate_id(), | ||||
|                     "externalId": ext_id, | ||||
|                     "members": [{"value": str(user.uuid)}], | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_members_empty(self): | ||||
|         """Test group create""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_duplicate(self): | ||||
|         """Test group create (duplicate)""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)} | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 409) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, | ||||
|             { | ||||
|                 "detail": "Group with ID exists already.", | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], | ||||
|                 "scimType": "uniqueness", | ||||
|                 "status": 409, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_group_update(self): | ||||
|         """Test group update""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)} | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|  | ||||
|     def test_group_update_non_existent(self): | ||||
|         """Test group update""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "group_id": str(uuid4()), | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=404) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, | ||||
|             { | ||||
|                 "detail": "Group not found.", | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], | ||||
|                 "status": 404, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_group_patch_add(self): | ||||
|         """Test group patch""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.patch( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "Add", | ||||
|                             "path": "members", | ||||
|                             "value": {"value": str(user.uuid)}, | ||||
|                         } | ||||
|                     ] | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         self.assertTrue(group.users.filter(pk=user.pk).exists()) | ||||
|  | ||||
|     def test_group_patch_remove(self): | ||||
|         """Test group patch""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         group.users.add(user) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.patch( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "remove", | ||||
|                             "path": "members", | ||||
|                             "value": {"value": str(user.uuid)}, | ||||
|                         } | ||||
|                     ] | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         self.assertFalse(group.users.filter(pk=user.pk).exists()) | ||||
|  | ||||
|     def test_group_delete(self): | ||||
|         """Test group delete""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.delete( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=204) | ||||
| @ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase): | ||||
|             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], | ||||
|             "0123456789", | ||||
|         ) | ||||
|  | ||||
|     def test_user_update(self): | ||||
|         """Test user update""" | ||||
|         user = create_test_user() | ||||
|         existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-users", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "user_id": str(user.uuid), | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "id": str(existing.pk), | ||||
|                     "userName": generate_id(), | ||||
|                     "externalId": ext_id, | ||||
|                     "emails": [ | ||||
|                         { | ||||
|                             "primary": True, | ||||
|                             "value": user.email, | ||||
|                         } | ||||
|                     ], | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_user_delete(self): | ||||
|         """Test user delete""" | ||||
|         user = create_test_user() | ||||
|         SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) | ||||
|         response = self.client.delete( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-users", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "user_id": str(user.uuid), | ||||
|                 }, | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|  | ||||
| @ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_ | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.views import APIView | ||||
|  | ||||
| from authentik.core.middleware import CTX_AUTH_VIA | ||||
| from authentik.core.models import Token, TokenIntents, User | ||||
| from authentik.sources.scim.models import SCIMSource | ||||
|  | ||||
| @ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication): | ||||
|         _username, _, password = b64decode(key.encode()).decode().partition(":") | ||||
|         token = self.check_token(password, source_slug) | ||||
|         if token: | ||||
|             CTX_AUTH_VIA.set("scim_basic") | ||||
|             return (token.user, token) | ||||
|         return None | ||||
|  | ||||
| @ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication): | ||||
|         token = self.check_token(key, source_slug) | ||||
|         if not token: | ||||
|             return None | ||||
|         CTX_AUTH_VIA.set("scim_token") | ||||
|         return (token.user, token) | ||||
|  | ||||
| @ -1,11 +1,13 @@ | ||||
| """SCIM Utils""" | ||||
|  | ||||
| from typing import Any | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.paginator import Page, Paginator | ||||
| from django.db.models import Q, QuerySet | ||||
| from django.http import HttpRequest | ||||
| from django.urls import resolve | ||||
| from rest_framework.parsers import JSONParser | ||||
| from rest_framework.permissions import IsAuthenticated | ||||
| from rest_framework.renderers import JSONRenderer | ||||
| @ -44,7 +46,7 @@ class SCIMView(APIView): | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     permission_classes = [IsAuthenticated] | ||||
|     parser_classes = [SCIMParser, JSONParser] | ||||
|     parser_classes = [SCIMParser] | ||||
|     renderer_classes = [SCIMRenderer] | ||||
|  | ||||
|     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: | ||||
| @ -54,6 +56,28 @@ class SCIMView(APIView): | ||||
|     def get_authenticators(self): | ||||
|         return [SCIMTokenAuth(self)] | ||||
|  | ||||
|     def patch_resolve_value(self, raw_value: dict) -> User | Group | None: | ||||
|         """Attempt to resolve a raw `value` attribute of a patch operation into | ||||
|         a database model""" | ||||
|         model = User | ||||
|         query = {} | ||||
|         if "$ref" in raw_value: | ||||
|             url = urlparse(raw_value["$ref"]) | ||||
|             if match := resolve(url.path): | ||||
|                 if match.url_name == "v2-users": | ||||
|                     model = User | ||||
|                     query = {"pk": int(match.kwargs["user_id"])} | ||||
|         elif "type" in raw_value: | ||||
|             match raw_value["type"]: | ||||
|                 case "User": | ||||
|                     model = User | ||||
|                     query = {"pk": int(raw_value["value"])} | ||||
|                 case "Group": | ||||
|                     model = Group | ||||
|         else: | ||||
|             return None | ||||
|         return model.objects.filter(**query).first() | ||||
|  | ||||
|     def filter_parse(self, request: Request): | ||||
|         """Parse the path of a Patch Operation""" | ||||
|         path = request.query_params.get("filter") | ||||
|  | ||||
| @ -1,58 +0,0 @@ | ||||
| from enum import Enum | ||||
|  | ||||
| from pydanticscim.responses import SCIMError as BaseSCIMError | ||||
| from rest_framework.exceptions import ValidationError | ||||
|  | ||||
|  | ||||
| class SCIMErrorTypes(Enum): | ||||
|     invalid_filter = "invalidFilter" | ||||
|     too_many = "tooMany" | ||||
|     uniqueness = "uniqueness" | ||||
|     mutability = "mutability" | ||||
|     invalid_syntax = "invalidSyntax" | ||||
|     invalid_path = "invalidPath" | ||||
|     no_target = "noTarget" | ||||
|     invalid_value = "invalidValue" | ||||
|     invalid_vers = "invalidVers" | ||||
|     sensitive = "sensitive" | ||||
|  | ||||
|  | ||||
| class SCIMError(BaseSCIMError): | ||||
|     scimType: SCIMErrorTypes | None = None | ||||
|     detail: str | None = None | ||||
|  | ||||
|  | ||||
| class SCIMValidationError(ValidationError): | ||||
|     status_code = 400 | ||||
|     default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400) | ||||
|  | ||||
|     def __init__(self, detail: SCIMError | None): | ||||
|         if detail is None: | ||||
|             detail = self.default_detail | ||||
|         detail.status = self.status_code | ||||
|         self.detail = detail.model_dump(mode="json", exclude_none=True) | ||||
|  | ||||
|  | ||||
| class SCIMConflictError(SCIMValidationError): | ||||
|     status_code = 409 | ||||
|  | ||||
|     def __init__(self, detail: str): | ||||
|         super().__init__( | ||||
|             SCIMError( | ||||
|                 detail=detail, | ||||
|                 scimType=SCIMErrorTypes.uniqueness, | ||||
|                 status=self.status_code, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SCIMNotFoundError(SCIMValidationError): | ||||
|     status_code = 404 | ||||
|  | ||||
|     def __init__(self, detail: str): | ||||
|         super().__init__( | ||||
|             SCIMError( | ||||
|                 detail=detail, | ||||
|                 status=self.status_code, | ||||
|             ) | ||||
|         ) | ||||
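The exception classes removed here render into the RFC 7644 error bodies asserted by the deleted TestSCIMGroups tests earlier in this diff. A self-contained pydantic stand-in showing the resulting payload shape (the class name is illustrative):

from pydantic import BaseModel


class SCIMErrorBody(BaseModel):
    """Shape of the error bodies asserted in the deleted SCIM source tests."""

    schemas: list[str] = ["urn:ietf:params:scim:api:messages:2.0:Error"]
    detail: str | None = None
    scimType: str | None = None
    status: int = 400


not_found = SCIMErrorBody(detail="Group not found.", status=404)
# exclude_none drops the unset scimType, matching the 404 body in the tests:
# {"detail": "Group not found.", "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], "status": 404}
print(not_found.model_dump(exclude_none=True))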
| @ -4,25 +4,19 @@ from uuid import uuid4 | ||||
|  | ||||
| from django.db.models import Q | ||||
| from django.db.transaction import atomic | ||||
| from django.http import QueryDict | ||||
| from django.http import Http404, QueryDict | ||||
| from django.urls import reverse | ||||
| from pydantic import ValidationError as PydanticValidationError | ||||
| from pydanticscim.group import GroupMember | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from scim2_filter_parser.attr_paths import AttrPath | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation | ||||
| from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupModel | ||||
| from authentik.sources.scim.models import SCIMSourceGroup | ||||
| from authentik.sources.scim.views.v2.base import SCIMObjectView | ||||
| from authentik.sources.scim.views.v2.exceptions import ( | ||||
|     SCIMConflictError, | ||||
|     SCIMNotFoundError, | ||||
|     SCIMValidationError, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class GroupsView(SCIMObjectView): | ||||
| @ -33,7 +27,7 @@ class GroupsView(SCIMObjectView): | ||||
|     def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: | ||||
|         """Convert Group to SCIM data""" | ||||
|         payload = SCIMGroupModel( | ||||
|             schemas=[SCIM_GROUP_SCHEMA], | ||||
|             schemas=[SCIM_USER_SCHEMA], | ||||
|             id=str(scim_group.group.pk), | ||||
|             externalId=scim_group.id, | ||||
|             displayName=scim_group.group.name, | ||||
| @ -64,7 +58,7 @@ class GroupsView(SCIMObjectView): | ||||
|         if group_id: | ||||
|             connection = base_query.filter(source=self.source, group__group_uuid=group_id).first() | ||||
|             if not connection: | ||||
|                 raise SCIMNotFoundError("Group not found.") | ||||
|                 raise Http404 | ||||
|             return Response(self.group_to_scim(connection)) | ||||
|         connections = ( | ||||
|             base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request)) | ||||
| @ -125,7 +119,7 @@ class GroupsView(SCIMObjectView): | ||||
|         ).first() | ||||
|         if connection: | ||||
|             self.logger.debug("Found existing group") | ||||
|             raise SCIMConflictError("Group with ID exists already.") | ||||
|             return Response(status=409) | ||||
|         connection = self.update_group(None, request.data) | ||||
|         return Response(self.group_to_scim(connection), status=201) | ||||
|  | ||||
| @ -135,44 +129,10 @@ class GroupsView(SCIMObjectView): | ||||
|             source=self.source, group__group_uuid=group_id | ||||
|         ).first() | ||||
|         if not connection: | ||||
|             raise SCIMNotFoundError("Group not found.") | ||||
|             raise Http404 | ||||
|         connection = self.update_group(connection, request.data) | ||||
|         return Response(self.group_to_scim(connection), status=200) | ||||
|  | ||||
|     @atomic | ||||
|     def patch(self, request: Request, group_id: str, **kwargs) -> Response: | ||||
|         """Patch group handler""" | ||||
|         connection = SCIMSourceGroup.objects.filter( | ||||
|             source=self.source, group__group_uuid=group_id | ||||
|         ).first() | ||||
|         if not connection: | ||||
|             raise SCIMNotFoundError("Group not found.") | ||||
|  | ||||
|         for _op in request.data.get("Operations", []): | ||||
|             operation = PatchOperation.model_validate(_op) | ||||
|             if operation.op.lower() not in ["add", "remove", "replace"]: | ||||
|                 raise SCIMValidationError() | ||||
|             attr_path = AttrPath(f'{operation.path} eq ""', {}) | ||||
|             if attr_path.first_path == ("members", None, None): | ||||
|                 # FIXME: this can probably be de-duplicated | ||||
|                 if operation.op == PatchOp.add: | ||||
|                     if not isinstance(operation.value, list): | ||||
|                         operation.value = [operation.value] | ||||
|                     query = Q() | ||||
|                     for member in operation.value: | ||||
|                         query |= Q(uuid=member["value"]) | ||||
|                     if query: | ||||
|                         connection.group.users.add(*User.objects.filter(query)) | ||||
|                 elif operation.op == PatchOp.remove: | ||||
|                     if not isinstance(operation.value, list): | ||||
|                         operation.value = [operation.value] | ||||
|                     query = Q() | ||||
|                     for member in operation.value: | ||||
|                         query |= Q(uuid=member["value"]) | ||||
|                     if query: | ||||
|                         connection.group.users.remove(*User.objects.filter(query)) | ||||
|         return Response(self.group_to_scim(connection), status=200) | ||||
|  | ||||
|     @atomic | ||||
|     def delete(self, request: Request, group_id: str, **kwargs) -> Response: | ||||
|         """Delete group handler""" | ||||
| @ -180,7 +140,7 @@ class GroupsView(SCIMObjectView): | ||||
|             source=self.source, group__group_uuid=group_id | ||||
|         ).first() | ||||
|         if not connection: | ||||
|             raise SCIMNotFoundError("Group not found.") | ||||
|             raise Http404 | ||||
|         connection.group.delete() | ||||
|         connection.delete() | ||||
|         return Response(status=204) | ||||
|  | ||||
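The removed patch() handler above only acts on the members attribute path and accepts single values or lists for add and remove operations. A request body of the shape exercised by the deleted TestSCIMGroups patch tests earlier in this diff (the UUID is a placeholder):

group_patch_add = {
    "Operations": [
        {
            "op": "Add",  # compared case-insensitively through the PatchOp enum
            "path": "members",
            "value": {"value": "3fa85f64-5717-4562-b3fc-2c963f66afa6"},
        }
    ]
}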
| @ -1,11 +1,11 @@ | ||||
| """SCIM Meta views""" | ||||
|  | ||||
| from django.http import Http404 | ||||
| from django.urls import reverse | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
|  | ||||
| from authentik.sources.scim.views.v2.base import SCIMView | ||||
| from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError | ||||
|  | ||||
|  | ||||
| class ResourceTypesView(SCIMView): | ||||
| @ -138,7 +138,7 @@ class ResourceTypesView(SCIMView): | ||||
|             resource = [x for x in resource_types if x.get("id") == resource_type] | ||||
|             if resource: | ||||
|                 return Response(resource[0]) | ||||
|             raise SCIMNotFoundError("Resource not found.") | ||||
|             raise Http404 | ||||
|         return Response( | ||||
|             { | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], | ||||
|  | ||||
| @ -3,12 +3,12 @@ | ||||
| from json import loads | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.http import Http404 | ||||
| from django.urls import reverse | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
|  | ||||
| from authentik.sources.scim.views.v2.base import SCIMView | ||||
| from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError | ||||
|  | ||||
| with open( | ||||
|     settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json", | ||||
| @ -44,7 +44,7 @@ class SchemaView(SCIMView): | ||||
|             schema = [x for x in schemas if x.get("id") == schema_uri] | ||||
|             if schema: | ||||
|                 return Response(schema[0]) | ||||
|             raise SCIMNotFoundError("Schema not found.") | ||||
|             raise Http404 | ||||
|         return Response( | ||||
|             { | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], | ||||
|  | ||||
| @ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView): | ||||
|             { | ||||
|                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], | ||||
|                 "authenticationSchemes": auth_schemas, | ||||
|                 # We only support patch for groups currently, so don't broadly advertise it. | ||||
|                 # Implementations that require Group patch will use it regardless of this flag. | ||||
|                 "patch": {"supported": False}, | ||||
|                 "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0}, | ||||
|                 "filter": { | ||||
|  | ||||
| @ -4,7 +4,7 @@ from uuid import uuid4 | ||||
|  | ||||
| from django.db.models import Q | ||||
| from django.db.transaction import atomic | ||||
| from django.http import QueryDict | ||||
| from django.http import Http404, QueryDict | ||||
| from django.urls import reverse | ||||
| from pydanticscim.user import Email, EmailKind, Name | ||||
| from rest_framework.exceptions import ValidationError | ||||
| @ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||
| from authentik.providers.scim.clients.schema import User as SCIMUserModel | ||||
| from authentik.sources.scim.models import SCIMSourceUser | ||||
| from authentik.sources.scim.views.v2.base import SCIMObjectView | ||||
| from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError | ||||
|  | ||||
|  | ||||
| class UsersView(SCIMObjectView): | ||||
| @ -70,7 +69,7 @@ class UsersView(SCIMObjectView): | ||||
|                 .first() | ||||
|             ) | ||||
|             if not connection: | ||||
|                 raise SCIMNotFoundError("User not found.") | ||||
|                 raise Http404 | ||||
|             return Response(self.user_to_scim(connection)) | ||||
|         connections = ( | ||||
|             SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk") | ||||
| @ -123,7 +122,7 @@ class UsersView(SCIMObjectView): | ||||
|         ).first() | ||||
|         if connection: | ||||
|             self.logger.debug("Found existing user") | ||||
|             raise SCIMConflictError("Group with ID exists already.") | ||||
|             return Response(status=409) | ||||
|         connection = self.update_user(None, request.data) | ||||
|         return Response(self.user_to_scim(connection), status=201) | ||||
|  | ||||
| @ -131,7 +130,7 @@ class UsersView(SCIMObjectView): | ||||
|         """Update user handler""" | ||||
|         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() | ||||
|         if not connection: | ||||
|             raise SCIMNotFoundError("User not found.") | ||||
|             raise Http404 | ||||
|         self.update_user(connection, request.data) | ||||
|         return Response(self.user_to_scim(connection), status=200) | ||||
|  | ||||
| @ -140,7 +139,7 @@ class UsersView(SCIMObjectView): | ||||
|         """Delete user handler""" | ||||
|         connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() | ||||
|         if not connection: | ||||
|             raise SCIMNotFoundError("User not found.") | ||||
|             raise Http404 | ||||
|         connection.user.delete() | ||||
|         connection.delete() | ||||
|         return Response(status=204) | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| """Validation stage challenge checking""" | ||||
|  | ||||
| from json import loads | ||||
| from typing import TYPE_CHECKING | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.http import HttpRequest | ||||
| @ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice | ||||
| from authentik.stages.authenticator_sms.models import SMSDevice | ||||
| from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses | ||||
| from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice | ||||
| from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE | ||||
| from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE | ||||
| from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| if TYPE_CHECKING: | ||||
|     from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView | ||||
|  | ||||
|  | ||||
| class DeviceChallenge(PassiveSerializer): | ||||
| @ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer): | ||||
|  | ||||
|  | ||||
| def get_challenge_for_device( | ||||
|     stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device | ||||
|     request: HttpRequest, stage: AuthenticatorValidateStage, device: Device | ||||
| ) -> dict: | ||||
|     """Generate challenge for a single device""" | ||||
|     if isinstance(device, WebAuthnDevice): | ||||
|         return get_webauthn_challenge(stage_view, stage, device) | ||||
|         return get_webauthn_challenge(request, stage, device) | ||||
|     if isinstance(device, EmailDevice): | ||||
|         return {"email": mask_email(device.email)} | ||||
|     # Code-based challenges have no hints | ||||
| @ -67,30 +64,26 @@ def get_challenge_for_device( | ||||
|  | ||||
|  | ||||
| def get_webauthn_challenge_without_user( | ||||
|     stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage | ||||
|     request: HttpRequest, stage: AuthenticatorValidateStage | ||||
| ) -> dict: | ||||
|     """Same as `get_webauthn_challenge`, but allows any client device. We can then later check | ||||
|     who the device belongs to.""" | ||||
|     stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) | ||||
|     request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||
|     authentication_options = generate_authentication_options( | ||||
|         rp_id=get_rp_id(stage_view.request), | ||||
|         rp_id=get_rp_id(request), | ||||
|         allow_credentials=[], | ||||
|         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), | ||||
|     ) | ||||
|     stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( | ||||
|         authentication_options.challenge | ||||
|     ) | ||||
|     request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge | ||||
|  | ||||
|     return loads(options_to_json(authentication_options)) | ||||
|  | ||||
|  | ||||
| def get_webauthn_challenge( | ||||
|     stage_view: "AuthenticatorValidateStageView", | ||||
|     stage: AuthenticatorValidateStage, | ||||
|     device: WebAuthnDevice | None = None, | ||||
|     request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None | ||||
| ) -> dict: | ||||
|     """Send the client a challenge that we'll check later""" | ||||
|     stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) | ||||
|     request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||
|  | ||||
|     allowed_credentials = [] | ||||
|  | ||||
| @ -101,14 +94,12 @@ def get_webauthn_challenge( | ||||
|             allowed_credentials.append(user_device.descriptor) | ||||
|  | ||||
|     authentication_options = generate_authentication_options( | ||||
|         rp_id=get_rp_id(stage_view.request), | ||||
|         rp_id=get_rp_id(request), | ||||
|         allow_credentials=allowed_credentials, | ||||
|         user_verification=UserVerificationRequirement(stage.webauthn_user_verification), | ||||
|     ) | ||||
|  | ||||
|     stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( | ||||
|         authentication_options.challenge | ||||
|     ) | ||||
|     request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge | ||||
|  | ||||
|     return loads(options_to_json(authentication_options)) | ||||
|  | ||||
| @ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev | ||||
| def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device: | ||||
|     """Validate WebAuthn Challenge""" | ||||
|     request = stage_view.request | ||||
|     challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE) | ||||
|     challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE) | ||||
|     stage: AuthenticatorValidateStage = stage_view.executor.current_stage | ||||
|     try: | ||||
|         credential = parse_authentication_credential_json(data) | ||||
|  | ||||
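Both sides of this hunk generate WebAuthn authentication options and stash options.challenge for later verification, in the flow plan context on one side and the session on the other. A minimal standalone sketch of that round trip, with a plain dict standing in for either store (names are illustrative):

from json import loads

from webauthn import generate_authentication_options, options_to_json
from webauthn.helpers.structs import UserVerificationRequirement

challenge_store: dict = {}  # stand-in for request.session or the flow plan context


def issue_webauthn_challenge(rp_id: str) -> dict:
    options = generate_authentication_options(
        rp_id=rp_id,
        allow_credentials=[],
        user_verification=UserVerificationRequirement.PREFERRED,
    )
    # Keep the raw challenge bytes so the browser's assertion can be checked against them later
    challenge_store["webauthn_challenge"] = options.challenge
    return loads(options_to_json(options))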
| @ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): | ||||
|                 data={ | ||||
|                     "device_class": device_class, | ||||
|                     "device_uid": device.pk, | ||||
|                     "challenge": get_challenge_for_device(self, stage, device), | ||||
|                     "challenge": get_challenge_for_device(self.request, stage, device), | ||||
|                     "last_used": device.last_used, | ||||
|                 } | ||||
|             ) | ||||
| @ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView): | ||||
|                 "device_class": DeviceClasses.WEBAUTHN, | ||||
|                 "device_uid": -1, | ||||
|                 "challenge": get_webauthn_challenge_without_user( | ||||
|                     self, | ||||
|                     self.request, | ||||
|                     self.executor.current_stage, | ||||
|                 ), | ||||
|                 "last_used": None, | ||||
|  | ||||
| @ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import ( | ||||
|     WebAuthnDevice, | ||||
|     WebAuthnDeviceType, | ||||
| ) | ||||
| from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE | ||||
| from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE | ||||
| from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import | ||||
| from authentik.stages.identification.models import IdentificationStage, UserFields | ||||
| from authentik.stages.user_login.models import UserLoginStage | ||||
| @ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|             device_classes=[DeviceClasses.WEBAUTHN], | ||||
|             webauthn_user_verification=UserVerification.PREFERRED, | ||||
|         ) | ||||
|         plan = FlowPlan("") | ||||
|         stage_view = AuthenticatorValidateStageView( | ||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request | ||||
|         ) | ||||
|         challenge = get_challenge_for_device(stage_view, stage, webauthn_device) | ||||
|         challenge = get_challenge_for_device(request, stage, webauthn_device) | ||||
|         del challenge["challenge"] | ||||
|         self.assertEqual( | ||||
|             challenge, | ||||
| @ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|  | ||||
|         with self.assertRaises(ValidationError): | ||||
|             validate_challenge_webauthn( | ||||
|                 {}, | ||||
|                 StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request), | ||||
|                 self.user, | ||||
|                 {}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user | ||||
|             ) | ||||
|  | ||||
|     def test_device_challenge_webauthn_restricted(self): | ||||
| @ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|             sign_count=0, | ||||
|             rp_id=generate_id(), | ||||
|         ) | ||||
|         plan = FlowPlan("") | ||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||
|         ) | ||||
|         stage_view = AuthenticatorValidateStageView( | ||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request | ||||
|         ) | ||||
|         challenge = get_challenge_for_device(stage_view, stage, webauthn_device) | ||||
|         challenge = get_challenge_for_device(request, stage, webauthn_device) | ||||
|         webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||
|         self.assertEqual( | ||||
|             challenge["allowCredentials"], | ||||
|             [ | ||||
|                 { | ||||
|                     "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU", | ||||
|                     "type": "public-key", | ||||
|                 } | ||||
|             ], | ||||
|         ) | ||||
|         self.assertIsNotNone(challenge["challenge"]) | ||||
|         self.assertEqual( | ||||
|             challenge["rpId"], | ||||
|             "testserver", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             challenge["timeout"], | ||||
|             60000, | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             challenge["userVerification"], | ||||
|             "preferred", | ||||
|             challenge, | ||||
|             { | ||||
|                 "allowCredentials": [ | ||||
|                     { | ||||
|                         "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU", | ||||
|                         "type": "public-key", | ||||
|                     } | ||||
|                 ], | ||||
|                 "challenge": bytes_to_base64url(webauthn_challenge), | ||||
|                 "rpId": "testserver", | ||||
|                 "timeout": 60000, | ||||
|                 "userVerification": "preferred", | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_get_challenge_userless(self): | ||||
| @ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|             sign_count=0, | ||||
|             rp_id=generate_id(), | ||||
|         ) | ||||
|         plan = FlowPlan("") | ||||
|         stage_view = AuthenticatorValidateStageView( | ||||
|             FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request | ||||
|         challenge = get_webauthn_challenge_without_user(request, stage) | ||||
|         webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||
|         self.assertEqual( | ||||
|             challenge, | ||||
|             { | ||||
|                 "allowCredentials": [], | ||||
|                 "challenge": bytes_to_base64url(webauthn_challenge), | ||||
|                 "rpId": "testserver", | ||||
|                 "timeout": 60000, | ||||
|                 "userVerification": "preferred", | ||||
|             }, | ||||
|         ) | ||||
|         challenge = get_webauthn_challenge_without_user(stage_view, stage) | ||||
|         self.assertEqual(challenge["allowCredentials"], []) | ||||
|         self.assertIsNotNone(challenge["challenge"]) | ||||
|         self.assertEqual(challenge["rpId"], "testserver") | ||||
|         self.assertEqual(challenge["timeout"], 60000) | ||||
|         self.assertEqual(challenge["userVerification"], "preferred") | ||||
|  | ||||
|     def test_validate_challenge_unrestricted(self): | ||||
|         """Test webauthn authentication (unrestricted webauthn device)""" | ||||
| @ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|                 "last_used": None, | ||||
|             } | ||||
|         ] | ||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" | ||||
|         ) | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         response = self.client.post( | ||||
| @ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|                 "last_used": None, | ||||
|             } | ||||
|         ] | ||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" | ||||
|         ) | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         response = self.client.post( | ||||
| @ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|                 "last_used": None, | ||||
|             } | ||||
|         ] | ||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||
|         ) | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         response = self.client.post( | ||||
| @ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase): | ||||
|             not_configured_action=NotConfiguredAction.CONFIGURE, | ||||
|             device_classes=[DeviceClasses.WEBAUTHN], | ||||
|         ) | ||||
|         plan = FlowPlan(flow.pk.hex) | ||||
|         plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||
|         stage_view = AuthenticatorValidateStageView( | ||||
|             FlowExecutorView(flow=flow, current_stage=stage), request=request | ||||
|         ) | ||||
|         request = get_request("/") | ||||
|         request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes( | ||||
|             "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" | ||||
|         ) | ||||
|         request.session.save() | ||||
|  | ||||
|         stage_view = AuthenticatorValidateStageView( | ||||
|             FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request | ||||
|             FlowExecutorView(flow=flow, current_stage=stage), request=request | ||||
|         ) | ||||
|         request.META["SERVER_NAME"] = "localhost" | ||||
|         request.META["SERVER_PORT"] = "9000" | ||||
|  | ||||
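A side note for reviewers unfamiliar with the session seeding these tests go back to: with Django's test client, the session has to be grabbed into a local variable, mutated, and saved explicitly before the request that reads it, because `Client.session` hands out a fresh store on each access. A minimal sketch of that pattern under plain Django defaults (hypothetical key name, and a JSON-serialisable string rather than the raw challenge bytes authentik's tests store):

```python
from django.test import TestCase

# Hypothetical key used only for this illustration.
SESSION_KEY_EXAMPLE = "example/webauthn/challenge"


class SessionSeedingExample(TestCase):
    def test_seed_session_before_request(self):
        # Client.session returns a new SessionStore wrapper each time it is
        # accessed, so keep one reference, write to it, and save explicitly.
        session = self.client.session
        session[SESSION_KEY_EXAMPLE] = (
            "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
        )
        session.save()

        # Requests issued through self.client now carry the session cookie
        # pointing at the saved data.
        self.assertIn(SESSION_KEY_EXAMPLE, self.client.session)
```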
| @ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer): | ||||
|             "resident_key_requirement", | ||||
|             "device_type_restrictions", | ||||
|             "device_type_restrictions_obj", | ||||
|             "max_attempts", | ||||
|         ] | ||||
|  | ||||
|  | ||||
|  | ||||
File diff suppressed because one or more lines are too long
							| @ -1,21 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-13 22:41 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ( | ||||
|             "authentik_stages_authenticator_webauthn", | ||||
|             "0012_webauthndevice_created_webauthndevice_last_updated_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.AddField( | ||||
|             model_name="authenticatorwebauthnstage", | ||||
|             name="max_attempts", | ||||
|             field=models.PositiveIntegerField(default=0), | ||||
|         ), | ||||
|     ] | ||||
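Because the target side of this compare predates the `AddField` migration deleted above, moving a deployment back to it needs that schema change unapplied first. `AddField` is automatically reversible, so migrating the app back to its 0012 ancestor is enough; a hedged sketch of doing that programmatically (the equivalent `manage.py migrate <app> <migration>` CLI form works the same way):

```python
# Sketch: unapply the max_attempts schema change before deploying the older code.
# Assumes DJANGO_SETTINGS_MODULE points at the project settings; the app label
# and migration name are taken from the dependency list shown above.
import django
from django.core.management import call_command

django.setup()
call_command(
    "migrate",
    "authentik_stages_authenticator_webauthn",
    "0012_webauthndevice_created_webauthndevice_last_updated_and_more",
)
```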
| @ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage): | ||||
|  | ||||
|     device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True) | ||||
|  | ||||
|     max_attempts = models.PositiveIntegerField(default=0) | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[BaseSerializer]: | ||||
|         from authentik.stages.authenticator_webauthn.api.stages import ( | ||||
|  | ||||
| @ -5,13 +5,12 @@ from uuid import UUID | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.http.request import QueryDict | ||||
| from django.utils.translation import gettext as __ | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from rest_framework.fields import CharField | ||||
| from rest_framework.serializers import ValidationError | ||||
| from webauthn import options_to_json | ||||
| from webauthn.helpers.bytes_to_base64url import bytes_to_base64url | ||||
| from webauthn.helpers.exceptions import WebAuthnException | ||||
| from webauthn.helpers.exceptions import InvalidRegistrationResponse | ||||
| from webauthn.helpers.structs import ( | ||||
|     AttestationConveyancePreference, | ||||
|     AuthenticatorAttachment, | ||||
| @ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import ( | ||||
| ) | ||||
| from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id | ||||
|  | ||||
| PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge" | ||||
| PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt" | ||||
| SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge" | ||||
|  | ||||
|  | ||||
| class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): | ||||
| @ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): | ||||
|  | ||||
|     def validate_response(self, response: dict) -> dict: | ||||
|         """Validate webauthn challenge response""" | ||||
|         challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] | ||||
|         challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] | ||||
|  | ||||
|         try: | ||||
|             registration: VerifiedRegistration = verify_registration_response( | ||||
| @ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse): | ||||
|                 expected_rp_id=get_rp_id(self.request), | ||||
|                 expected_origin=get_origin(self.request), | ||||
|             ) | ||||
|         except WebAuthnException as exc: | ||||
|         except InvalidRegistrationResponse as exc: | ||||
|             self.stage.logger.warning("registration failed", exc=exc) | ||||
|             raise ValidationError(f"Registration failed. Error: {exc}") from None | ||||
|  | ||||
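The only functional difference in the hunk above is how broadly registration errors are caught. An illustration of the two variants, assuming (as the paired import change suggests) that in py_webauthn `InvalidRegistrationResponse` is the specific error raised by `verify_registration_response()` on a bad credential, while `WebAuthnException` is the wider base class that also covers errors such as malformed input:

```python
# Illustration only, not authentik code.
from webauthn import verify_registration_response
from webauthn.helpers.exceptions import InvalidRegistrationResponse, WebAuthnException


def registration_ok(credential: dict | str, challenge: bytes, narrow: bool = True) -> bool:
    """Return True if the registration response verifies, False if it is rejected."""
    catch = InvalidRegistrationResponse if narrow else WebAuthnException
    try:
        verify_registration_response(
            credential=credential,
            expected_challenge=challenge,
            expected_rp_id="example.com",
            expected_origin="https://example.com",
        )
    except catch:  # narrow: registration failures only; broad: any library error
        return False
    return True
```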
| @ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | ||||
|     response_class = AuthenticatorWebAuthnChallengeResponse | ||||
|  | ||||
|     def get_challenge(self, *args, **kwargs) -> Challenge: | ||||
|         # clear session variables prior to starting a new registration | ||||
|         self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||
|         stage: AuthenticatorWebAuthnStage = self.executor.current_stage | ||||
|         self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0) | ||||
|         # clear flow variables prior to starting a new registration | ||||
|         self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) | ||||
|         user = self.get_pending_user() | ||||
|  | ||||
|         # library accepts none so we store null in the database, but if there is a value | ||||
| @ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | ||||
|             attestation=AttestationConveyancePreference.DIRECT, | ||||
|         ) | ||||
|  | ||||
|         self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge | ||||
|         self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge | ||||
|         self.request.session.save() | ||||
|         return AuthenticatorWebAuthnChallenge( | ||||
|             data={ | ||||
|                 "registration": loads(options_to_json(registration_options)), | ||||
| @ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | ||||
|         response.user = self.get_pending_user() | ||||
|         return response | ||||
|  | ||||
|     def challenge_invalid(self, response): | ||||
|         stage: AuthenticatorWebAuthnStage = self.executor.current_stage | ||||
|         self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0) | ||||
|         self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1 | ||||
|         if ( | ||||
|             stage.max_attempts > 0 | ||||
|             and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts | ||||
|         ): | ||||
|             return self.executor.stage_invalid( | ||||
|                 __( | ||||
|                     "Exceeded maximum attempts. " | ||||
|                     "Contact your {brand} administrator for help.".format( | ||||
|                         brand=self.request.brand.branding_title | ||||
|                     ) | ||||
|                 ) | ||||
|             ) | ||||
|         return super().challenge_invalid(response) | ||||
|  | ||||
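The override deleted above enforced a per-flow attempt budget: the counter is kept in the plan context and `max_attempts = 0` (the field's default) leaves registration attempts unlimited. A plain-Python restatement of just that rule, for reviewers checking the revert:

```python
# Standalone illustration of the attempt-limit arithmetic being removed;
# no authentik imports. max_attempts == 0 means "no limit".
def should_invalidate(attempts_so_far: int, max_attempts: int) -> tuple[int, bool]:
    """Return the updated counter and whether the stage should be failed."""
    attempts = attempts_so_far + 1
    return attempts, max_attempts > 0 and attempts >= max_attempts


if __name__ == "__main__":
    # With max_attempts = 3, the third failed registration ends the flow.
    counter, failed = 0, False
    for _ in range(3):
        counter, failed = should_invalidate(counter, max_attempts=3)
    assert failed is True

    # With the default of 0, the limit is never reached.
    _, failed = should_invalidate(99, max_attempts=0)
    assert failed is False
```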
|     def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: | ||||
|         # Webauthn Challenge has already been validated | ||||
|         webauthn_credential: VerifiedRegistration = response.validated_data["response"] | ||||
| @ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView): | ||||
|         else: | ||||
|             return self.executor.stage_invalid("Device with Credential ID already exists.") | ||||
|         return self.executor.stage_ok() | ||||
|  | ||||
|     def cleanup(self): | ||||
|         self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None) | ||||
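Stepping back from the diff, the restored design is a three-step session lifecycle: the challenge is written in `get_challenge()`, read during response validation, and dropped in `cleanup()`. A minimal standalone sketch of that shape, using plain Django views rather than authentik's stage classes (the key name and view names are made up for the example; the base64url form is stored so the default JSON session serializer can handle it):

```python
from secrets import token_bytes

from django.http import HttpRequest, JsonResponse
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url

# Hypothetical key name for illustration only.
SESSION_KEY_CHALLENGE = "example/webauthn/challenge"


def issue_challenge(request: HttpRequest) -> JsonResponse:
    """First leg: mint a challenge and park it in the caller's session."""
    challenge = bytes_to_base64url(token_bytes(64))
    request.session[SESSION_KEY_CHALLENGE] = challenge
    return JsonResponse({"challenge": challenge})


def verify_challenge(request: HttpRequest) -> JsonResponse:
    """Second leg: read the parked challenge back and always clear it."""
    expected = request.session.pop(SESSION_KEY_CHALLENGE, None)
    if expected is None:
        return JsonResponse({"error": "no pending challenge"}, status=400)
    return JsonResponse({"ok": request.POST.get("challenge", "") == expected})
```

Popping rather than merely reading mirrors the `cleanup()` hook above: a challenge is single-use, so it should not survive past the request that consumes it.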
|  | ||||
Some files were not shown because too many files have changed in this diff