Compare commits
	
		
			2 Commits
		
	
	
		
			enterprise
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7549a6b83d | |||
| bb45b714e2 | 
| @ -1,5 +1,5 @@ | ||||
| [bumpversion] | ||||
| current_version = 2025.6.3 | ||||
| current_version = 2025.6.1 | ||||
| tag = True | ||||
| commit = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||
| @ -21,8 +21,6 @@ optional_value = final | ||||
|  | ||||
| [bumpversion:file:package.json] | ||||
|  | ||||
| [bumpversion:file:package-lock.json] | ||||
|  | ||||
| [bumpversion:file:docker-compose.yml] | ||||
|  | ||||
| [bumpversion:file:schema.yml] | ||||
| @ -33,4 +31,6 @@ optional_value = final | ||||
|  | ||||
| [bumpversion:file:internal/constants/constants.go] | ||||
|  | ||||
| [bumpversion:file:web/src/common/constants.ts] | ||||
|  | ||||
| [bumpversion:file:lifecycle/aws/template.yaml] | ||||
|  | ||||
| @ -7,9 +7,6 @@ charset = utf-8 | ||||
| trim_trailing_whitespace = true | ||||
| insert_final_newline = true | ||||
|  | ||||
| [*.toml] | ||||
| indent_size = 2 | ||||
|  | ||||
| [*.html] | ||||
| indent_size = 2 | ||||
|  | ||||
|  | ||||
| @ -38,8 +38,6 @@ jobs: | ||||
|       # Needed for attestation | ||||
|       id-token: write | ||||
|       attestations: write | ||||
|       # Needed for checkout | ||||
|       contents: read | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - uses: docker/setup-qemu-action@v3.6.0 | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							| @ -9,15 +9,14 @@ on: | ||||
|  | ||||
| jobs: | ||||
|   test-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         version: | ||||
|           - docs | ||||
|           - version-2025-4 | ||||
|           - version-2025-2 | ||||
|           - version-2024-12 | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - run: | | ||||
|  | ||||
							
								
								
									
										6
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							| @ -202,7 +202,7 @@ jobs: | ||||
|         uses: actions/cache@v4 | ||||
|         with: | ||||
|           path: web/dist | ||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||
|           key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b | ||||
|       - name: prepare web ui | ||||
|         if: steps.cache-web.outputs.cache-hit != 'true' | ||||
|         working-directory: web | ||||
| @ -247,13 +247,11 @@ jobs: | ||||
|       # Needed for attestation | ||||
|       id-token: write | ||||
|       attestations: write | ||||
|       # Needed for checkout | ||||
|       contents: read | ||||
|     needs: ci-core-mark | ||||
|     uses: ./.github/workflows/_reusable-docker-build.yaml | ||||
|     secrets: inherit | ||||
|     with: | ||||
|       image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }} | ||||
|       image_name: ghcr.io/goauthentik/dev-server | ||||
|       release: false | ||||
|   pr-comment: | ||||
|     needs: | ||||
|  | ||||
							
								
								
									
										1
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/ci-outpost.yml
									
									
									
									
										vendored
									
									
								
							| @ -59,7 +59,6 @@ jobs: | ||||
|         with: | ||||
|           jobs: ${{ toJSON(needs) }} | ||||
|   build-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     timeout-minutes: 120 | ||||
|     needs: | ||||
|       - ci-outpost-mark | ||||
|  | ||||
							
								
								
									
										24
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							| @ -41,29 +41,7 @@ jobs: | ||||
|       - name: test | ||||
|         working-directory: website/ | ||||
|         run: npm test | ||||
|   build: | ||||
|     runs-on: ubuntu-latest | ||||
|     name: ${{ matrix.job }} | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         job: | ||||
|           - build | ||||
|           - build:integrations | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - uses: actions/setup-node@v4 | ||||
|         with: | ||||
|           node-version-file: website/package.json | ||||
|           cache: "npm" | ||||
|           cache-dependency-path: website/package-lock.json | ||||
|       - working-directory: website/ | ||||
|         run: npm ci | ||||
|       - name: build | ||||
|         working-directory: website/ | ||||
|         run: npm run ${{ matrix.job }} | ||||
|   build-container: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     permissions: | ||||
|       # Needed to upload container images to ghcr.io | ||||
| @ -116,11 +94,9 @@ jobs: | ||||
|     needs: | ||||
|       - lint | ||||
|       - test | ||||
|       - build | ||||
|       - build-container | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: re-actors/alls-green@release/v1 | ||||
|         with: | ||||
|           jobs: ${{ toJSON(needs) }} | ||||
|           allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }} | ||||
|  | ||||
							
								
								
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							| @ -2,7 +2,7 @@ name: "CodeQL" | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     branches: [main, next, version*] | ||||
|     branches: [main, "*", next, version*] | ||||
|   pull_request: | ||||
|     branches: [main] | ||||
|   schedule: | ||||
|  | ||||
							
								
								
									
										21
									
								
								.github/workflows/repo-mirror-cleanup.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								.github/workflows/repo-mirror-cleanup.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,21 +0,0 @@ | ||||
| name: "authentik-repo-mirror-cleanup" | ||||
|  | ||||
| on: | ||||
|   workflow_dispatch: | ||||
|  | ||||
| jobs: | ||||
|   to_internal: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - if: ${{ env.MIRROR_KEY != '' }} | ||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb | ||||
|         with: | ||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} | ||||
|           args: --tags --force --prune | ||||
|         env: | ||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} | ||||
							
								
								
									
										9
									
								
								.github/workflows/repo-mirror.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								.github/workflows/repo-mirror.yml
									
									
									
									
										vendored
									
									
								
							| @ -11,10 +11,11 @@ jobs: | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|       - if: ${{ env.MIRROR_KEY != '' }} | ||||
|         uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb | ||||
|         uses: pixta-dev/repository-mirroring-action@v1 | ||||
|         with: | ||||
|           target_repo_url: git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} | ||||
|           args: --tags --force | ||||
|           target_repo_url: | ||||
|             git@github.com:goauthentik/authentik-internal.git | ||||
|           ssh_private_key: | ||||
|             ${{ secrets.GH_MIRROR_KEY }} | ||||
|         env: | ||||
|           MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} | ||||
|  | ||||
| @ -16,7 +16,6 @@ env: | ||||
|  | ||||
| jobs: | ||||
|   compile: | ||||
|     if: ${{ github.repository != 'goauthentik/authentik-internal' }} | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - id: generate_token | ||||
|  | ||||
							
								
								
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							| @ -6,15 +6,13 @@ | ||||
|         "!Context scalar", | ||||
|         "!Enumerate sequence", | ||||
|         "!Env scalar", | ||||
|         "!Env sequence", | ||||
|         "!Find sequence", | ||||
|         "!Format sequence", | ||||
|         "!If sequence", | ||||
|         "!Index scalar", | ||||
|         "!KeyOf scalar", | ||||
|         "!Value scalar", | ||||
|         "!AtIndex scalar", | ||||
|         "!ParseJSON scalar" | ||||
|         "!AtIndex scalar" | ||||
|     ], | ||||
|     "typescript.preferences.importModuleSpecifier": "non-relative", | ||||
|     "typescript.preferences.importModuleSpecifierEnding": "index", | ||||
|  | ||||
| @ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \ | ||||
|     /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" | ||||
|  | ||||
| # Stage 4: Download uv | ||||
| FROM ghcr.io/astral-sh/uv:0.7.17 AS uv | ||||
| FROM ghcr.io/astral-sh/uv:0.7.13 AS uv | ||||
| # Stage 5: Base python image | ||||
| FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base | ||||
| FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base | ||||
|  | ||||
| ENV VENV_PATH="/ak-root/.venv" \ | ||||
|     PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ | ||||
|  | ||||
							
								
								
									
										10
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								Makefile
									
									
									
									
									
								
							| @ -86,10 +86,6 @@ dev-create-db: | ||||
|  | ||||
| dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. | ||||
|  | ||||
| update-test-mmdb:  ## Update test GeoIP and ASN Databases | ||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb | ||||
| 	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb | ||||
|  | ||||
| ######################### | ||||
| ## API Schema | ||||
| ######################### | ||||
| @ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts  ## Build and install the authentik API for Typescri | ||||
| 		--additional-properties=npmVersion=${NPM_VERSION} \ | ||||
| 		--git-repo-id authentik \ | ||||
| 		--git-user-id goauthentik | ||||
|  | ||||
| 	cd ${PWD}/${GEN_API_TS} && npm link | ||||
| 	cd ${PWD}/web && npm link @goauthentik/api | ||||
| 	mkdir -p web/node_modules/@goauthentik/api | ||||
| 	cd ${PWD}/${GEN_API_TS} && npm i | ||||
| 	\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api | ||||
|  | ||||
| gen-client-py: gen-clean-py ## Build and install the authentik API for Python | ||||
| 	docker run \ | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from os import environ | ||||
|  | ||||
| __version__ = "2025.6.3" | ||||
| __version__ = "2025.6.1" | ||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -37,7 +37,6 @@ entries: | ||||
|     - attrs: | ||||
|           attributes: | ||||
|               env_null: !Env [bar-baz, null] | ||||
|               json_parse: !ParseJSON '{"foo": "bar"}' | ||||
|               policy_pk1: | ||||
|                   !Format [ | ||||
|                       "%s-%s", | ||||
|  | ||||
| @ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable: | ||||
|  | ||||
|  | ||||
| for blueprint_file in Path("blueprints/").glob("**/*.yaml"): | ||||
|     if "local" in str(blueprint_file) or "testing" in str(blueprint_file): | ||||
|     if "local" in str(blueprint_file): | ||||
|         continue | ||||
|     setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) | ||||
|  | ||||
| @ -5,6 +5,7 @@ from collections.abc import Callable | ||||
| from django.apps import apps | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.blueprints.v1.importer import is_model_allowed | ||||
| from authentik.lib.models import SerializerModel | ||||
| from authentik.providers.oauth2.models import RefreshToken | ||||
|  | ||||
| @ -21,13 +22,10 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable: | ||||
|             return | ||||
|         model_class = test_model() | ||||
|         self.assertTrue(isinstance(model_class, SerializerModel)) | ||||
|         # Models that have subclasses don't have to have a serializer | ||||
|         if len(test_model.__subclasses__()) > 0: | ||||
|             return | ||||
|         self.assertIsNotNone(model_class.serializer) | ||||
|         if model_class.serializer.Meta().model == RefreshToken: | ||||
|             return | ||||
|         self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model)) | ||||
|         self.assertEqual(model_class.serializer.Meta().model, test_model) | ||||
|  | ||||
|     return tester | ||||
|  | ||||
| @ -36,6 +34,6 @@ for app in apps.get_app_configs(): | ||||
|     if not app.label.startswith("authentik"): | ||||
|         continue | ||||
|     for model in app.get_models(): | ||||
|         if not issubclass(model, SerializerModel): | ||||
|         if not is_model_allowed(model): | ||||
|             continue | ||||
|         setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model)) | ||||
|  | ||||
| @ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase): | ||||
|                     }, | ||||
|                     "nested_context": "context-nested-value", | ||||
|                     "env_null": None, | ||||
|                     "json_parse": {"foo": "bar"}, | ||||
|                     "at_index_sequence": "foo", | ||||
|                     "at_index_sequence_default": "non existent", | ||||
|                     "at_index_mapping": 2, | ||||
|  | ||||
| @ -6,7 +6,6 @@ from copy import copy | ||||
| from dataclasses import asdict, dataclass, field, is_dataclass | ||||
| from enum import Enum | ||||
| from functools import reduce | ||||
| from json import JSONDecodeError, loads | ||||
| from operator import ixor | ||||
| from os import getenv | ||||
| from typing import Any, Literal, Union | ||||
| @ -292,22 +291,6 @@ class Context(YAMLTag): | ||||
|         return value | ||||
|  | ||||
|  | ||||
| class ParseJSON(YAMLTag): | ||||
|     """Parse JSON from context/env/etc value""" | ||||
|  | ||||
|     raw: str | ||||
|  | ||||
|     def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: | ||||
|         super().__init__() | ||||
|         self.raw = node.value | ||||
|  | ||||
|     def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any: | ||||
|         try: | ||||
|             return loads(self.raw) | ||||
|         except JSONDecodeError as exc: | ||||
|             raise EntryInvalidError.from_entry(exc, entry) from exc | ||||
|  | ||||
|  | ||||
| class Format(YAMLTag): | ||||
|     """Format a string""" | ||||
|  | ||||
| @ -683,7 +666,6 @@ class BlueprintLoader(SafeLoader): | ||||
|         self.add_constructor("!Value", Value) | ||||
|         self.add_constructor("!Index", Index) | ||||
|         self.add_constructor("!AtIndex", AtIndex) | ||||
|         self.add_constructor("!ParseJSON", ParseJSON) | ||||
|  | ||||
|  | ||||
| class EntryInvalidError(SentryIgnoredException): | ||||
|  | ||||
| @ -43,7 +43,6 @@ from authentik.core.models import ( | ||||
| ) | ||||
| from authentik.enterprise.license import LicenseKey | ||||
| from authentik.enterprise.models import LicenseUsage | ||||
| from authentik.enterprise.providers.apple_psso.models import AppleNonce | ||||
| from authentik.enterprise.providers.google_workspace.models import ( | ||||
|     GoogleWorkspaceProviderGroup, | ||||
|     GoogleWorkspaceProviderUser, | ||||
| @ -136,7 +135,6 @@ def excluded_models() -> list[type[Model]]: | ||||
|         EndpointDeviceConnection, | ||||
|         DeviceToken, | ||||
|         StreamEvent, | ||||
|         AppleNonce, | ||||
|     ) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -1,6 +1,8 @@ | ||||
| """Authenticator Devices API Views""" | ||||
|  | ||||
| from drf_spectacular.utils import extend_schema | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from drf_spectacular.types import OpenApiTypes | ||||
| from drf_spectacular.utils import OpenApiParameter, extend_schema | ||||
| from guardian.shortcuts import get_objects_for_user | ||||
| from rest_framework.fields import ( | ||||
|     BooleanField, | ||||
| @ -13,7 +15,6 @@ from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.viewsets import ViewSet | ||||
|  | ||||
| from authentik.core.api.users import ParamUserSerializer | ||||
| from authentik.core.api.utils import MetaNameSerializer | ||||
| from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice | ||||
| from authentik.stages.authenticator import device_classes, devices_for_user | ||||
| @ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice | ||||
|  | ||||
|  | ||||
| class DeviceSerializer(MetaNameSerializer): | ||||
|     """Serializer for authenticator devices""" | ||||
|     """Serializer for Duo authenticator devices""" | ||||
|  | ||||
|     pk = CharField() | ||||
|     name = CharField() | ||||
| @ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer): | ||||
|     last_updated = DateTimeField(read_only=True) | ||||
|     last_used = DateTimeField(read_only=True, allow_null=True) | ||||
|     extra_description = SerializerMethodField() | ||||
|     external_id = SerializerMethodField() | ||||
|  | ||||
|     def get_type(self, instance: Device) -> str: | ||||
|         """Get type of device""" | ||||
|         return instance._meta.label | ||||
|  | ||||
|     def get_extra_description(self, instance: Device) -> str | None: | ||||
|     def get_extra_description(self, instance: Device) -> str: | ||||
|         """Get extra description""" | ||||
|         if isinstance(instance, WebAuthnDevice): | ||||
|             return instance.device_type.description if instance.device_type else None | ||||
|             return ( | ||||
|                 instance.device_type.description | ||||
|                 if instance.device_type | ||||
|                 else _("Extra description not available") | ||||
|             ) | ||||
|         if isinstance(instance, EndpointDevice): | ||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||
|         return None | ||||
|  | ||||
|     def get_external_id(self, instance: Device) -> str | None: | ||||
|         """Get external Device ID""" | ||||
|         if isinstance(instance, WebAuthnDevice): | ||||
|             return instance.device_type.aaguid if instance.device_type else None | ||||
|         if isinstance(instance, EndpointDevice): | ||||
|             return instance.data.get("deviceSignals", {}).get("deviceModel") | ||||
|         return None | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| class DeviceViewSet(ViewSet): | ||||
| @ -61,6 +57,7 @@ class DeviceViewSet(ViewSet): | ||||
|     serializer_class = DeviceSerializer | ||||
|     permission_classes = [IsAuthenticated] | ||||
|  | ||||
|     @extend_schema(responses={200: DeviceSerializer(many=True)}) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """Get all devices for current user""" | ||||
|         devices = devices_for_user(request.user) | ||||
| @ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet): | ||||
|             yield from device_set | ||||
|  | ||||
|     @extend_schema( | ||||
|         parameters=[ParamUserSerializer], | ||||
|         parameters=[ | ||||
|             OpenApiParameter( | ||||
|                 name="user", | ||||
|                 location=OpenApiParameter.QUERY, | ||||
|                 type=OpenApiTypes.INT, | ||||
|             ) | ||||
|         ], | ||||
|         responses={200: DeviceSerializer(many=True)}, | ||||
|     ) | ||||
|     def list(self, request: Request) -> Response: | ||||
|         """Get all devices for current user""" | ||||
|         args = ParamUserSerializer(data=request.query_params) | ||||
|         args.is_valid(raise_exception=True) | ||||
|         return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data) | ||||
|         kwargs = {} | ||||
|         if "user" in request.query_params: | ||||
|             kwargs = {"user": request.query_params["user"]} | ||||
|         return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data) | ||||
|  | ||||
| @ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| class ParamUserSerializer(PassiveSerializer): | ||||
|     """Partial serializer for query parameters to select a user""" | ||||
|  | ||||
|     user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False) | ||||
|  | ||||
|  | ||||
| class UserGroupSerializer(ModelSerializer): | ||||
|     """Simplified Group Serializer for user's groups""" | ||||
|  | ||||
| @ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet): | ||||
|     queryset = User.objects.none() | ||||
|     ordering = ["username"] | ||||
|     serializer_class = UserSerializer | ||||
|     filterset_class = UsersFilter | ||||
|     search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] | ||||
|  | ||||
|     def get_ql_fields(self): | ||||
|         from djangoql.schema import BoolField, StrField | ||||
|  | ||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField | ||||
|  | ||||
|         return [ | ||||
|             StrField(User, "username"), | ||||
|             StrField(User, "name"), | ||||
|             StrField(User, "email"), | ||||
|             StrField(User, "path"), | ||||
|             BoolField(User, "is_active", nullable=True), | ||||
|             ChoiceSearchField(User, "type"), | ||||
|             JSONSearchField(User, "attributes", suggest_nested=False), | ||||
|         ] | ||||
|     filterset_class = UsersFilter | ||||
|  | ||||
|     def get_queryset(self): | ||||
|         base_qs = User.objects.all().exclude_anonymous() | ||||
|  | ||||
| @ -2,7 +2,6 @@ | ||||
|  | ||||
| from typing import Any | ||||
|  | ||||
| from django.db import models | ||||
| from django.db.models import Model | ||||
| from drf_spectacular.extensions import OpenApiSerializerFieldExtension | ||||
| from drf_spectacular.plumbing import build_basic_type | ||||
| @ -31,27 +30,7 @@ def is_dict(value: Any): | ||||
|     raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") | ||||
|  | ||||
|  | ||||
| class JSONDictField(JSONField): | ||||
|     """JSON Field which only allows dictionaries""" | ||||
|  | ||||
|     default_validators = [is_dict] | ||||
|  | ||||
|  | ||||
| class JSONExtension(OpenApiSerializerFieldExtension): | ||||
|     """Generate API Schema for JSON fields as""" | ||||
|  | ||||
|     target_class = "authentik.core.api.utils.JSONDictField" | ||||
|  | ||||
|     def map_serializer_field(self, auto_schema, direction): | ||||
|         return build_basic_type(OpenApiTypes.OBJECT) | ||||
|  | ||||
|  | ||||
| class ModelSerializer(BaseModelSerializer): | ||||
|  | ||||
|     # By default, JSON fields we have are used to store dictionaries | ||||
|     serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy() | ||||
|     serializer_field_mapping[models.JSONField] = JSONDictField | ||||
|  | ||||
|     def create(self, validated_data): | ||||
|         instance = super().create(validated_data) | ||||
|  | ||||
| @ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer): | ||||
|         return instance | ||||
|  | ||||
|  | ||||
| class JSONDictField(JSONField): | ||||
|     """JSON Field which only allows dictionaries""" | ||||
|  | ||||
|     default_validators = [is_dict] | ||||
|  | ||||
|  | ||||
| class JSONExtension(OpenApiSerializerFieldExtension): | ||||
|     """Generate API Schema for JSON fields as""" | ||||
|  | ||||
|     target_class = "authentik.core.api.utils.JSONDictField" | ||||
|  | ||||
|     def map_serializer_field(self, auto_schema, direction): | ||||
|         return build_basic_type(OpenApiTypes.OBJECT) | ||||
|  | ||||
|  | ||||
| class PassiveSerializer(Serializer): | ||||
|     """Base serializer class which doesn't implement create/update methods""" | ||||
|  | ||||
|  | ||||
| @ -13,6 +13,7 @@ class Command(TenantCommand): | ||||
|         parser.add_argument("usernames", nargs="*", type=str) | ||||
|  | ||||
|     def handle_per_tenant(self, **options): | ||||
|         print(options) | ||||
|         new_type = UserTypes(options["type"]) | ||||
|         qs = ( | ||||
|             User.objects.exclude_anonymous() | ||||
|  | ||||
| @ -18,7 +18,7 @@ from django.http import HttpRequest | ||||
| from django.utils.functional import SimpleLazyObject, cached_property | ||||
| from django.utils.timezone import now | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_cte import CTE, with_cte | ||||
| from django_cte import CTEQuerySet, With | ||||
| from guardian.conf import settings | ||||
| from guardian.mixins import GuardianUserMixin | ||||
| from model_utils.managers import InheritanceManager | ||||
| @ -136,7 +136,7 @@ class AttributesMixin(models.Model): | ||||
|         return instance, False | ||||
|  | ||||
|  | ||||
| class GroupQuerySet(QuerySet): | ||||
| class GroupQuerySet(CTEQuerySet): | ||||
|     def with_children_recursive(self): | ||||
|         """Recursively get all groups that have the current queryset as parents | ||||
|         or are indirectly related.""" | ||||
| @ -165,9 +165,9 @@ class GroupQuerySet(QuerySet): | ||||
|             ) | ||||
|  | ||||
|         # Build the recursive query, see above | ||||
|         cte = CTE.recursive(make_cte) | ||||
|         cte = With.recursive(make_cte) | ||||
|         # Return the result, as a usable queryset for Group. | ||||
|         return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid)) | ||||
|         return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte) | ||||
|  | ||||
|  | ||||
| class Group(SerializerModel, AttributesMixin): | ||||
| @ -1082,12 +1082,6 @@ class AuthenticatedSession(SerializerModel): | ||||
|  | ||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer | ||||
|  | ||||
|         return AuthenticatedSessionSerializer | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Authenticated Session") | ||||
|         verbose_name_plural = _("Authenticated Sessions") | ||||
|  | ||||
| @ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase): | ||||
|             [ | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter1",  # nosec | ||||
|                     old_password="hunter1",  # nosec B106 | ||||
|                     created_at=_now - timedelta(days=3), | ||||
|                 ), | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter2",  # nosec | ||||
|                     old_password="hunter2",  # nosec B106 | ||||
|                     created_at=_now - timedelta(days=2), | ||||
|                 ), | ||||
|                 UserPasswordHistory( | ||||
|                     user=self.user, | ||||
|                     old_password="hunter3",  # nosec | ||||
|                     old_password="hunter3",  # nosec B106 | ||||
|                     created_at=_now, | ||||
|                 ), | ||||
|             ] | ||||
|  | ||||
| @ -1,32 +0,0 @@ | ||||
| """Apple Platform SSO Provider API Views""" | ||||
|  | ||||
| from rest_framework.viewsets import ModelViewSet | ||||
|  | ||||
| from authentik.core.api.providers import ProviderSerializer | ||||
| from authentik.core.api.used_by import UsedByMixin | ||||
| from authentik.enterprise.api import EnterpriseRequiredMixin | ||||
| from authentik.enterprise.providers.apple_psso.models import ApplePlatformSSOProvider | ||||
|  | ||||
|  | ||||
| class ApplePlatformSSOProviderSerializer(EnterpriseRequiredMixin, ProviderSerializer): | ||||
|     """ApplePlatformSSOProvider Serializer""" | ||||
|  | ||||
|     class Meta: | ||||
|         model = ApplePlatformSSOProvider | ||||
|         fields = [ | ||||
|             "pk", | ||||
|             "name", | ||||
|         ] | ||||
|         extra_kwargs = {} | ||||
|  | ||||
|  | ||||
| class ApplePlatformSSOProviderViewSet(UsedByMixin, ModelViewSet): | ||||
|     """ApplePlatformSSOProvider Viewset""" | ||||
|  | ||||
|     queryset = ApplePlatformSSOProvider.objects.all() | ||||
|     serializer_class = ApplePlatformSSOProviderSerializer | ||||
|     filterset_fields = [ | ||||
|         "name", | ||||
|     ] | ||||
|     search_fields = ["name"] | ||||
|     ordering = ["name"] | ||||
| @ -1,13 +0,0 @@ | ||||
| from authentik.enterprise.apps import EnterpriseConfig | ||||
|  | ||||
|  | ||||
| class AuthentikEnterpriseProviderApplePSSOConfig(EnterpriseConfig): | ||||
|  | ||||
|     name = "authentik.enterprise.providers.apple_psso" | ||||
|     label = "authentik_providers_apple_psso" | ||||
|     verbose_name = "authentik Enterprise.Providers.Apple Platform SSO" | ||||
|     default = True | ||||
|     mountpoints = { | ||||
|         "authentik.enterprise.providers.apple_psso.urls": "endpoint/apple/sso/", | ||||
|         "authentik.enterprise.providers.apple_psso.urls_root": "", | ||||
|     } | ||||
| @ -1,118 +0,0 @@ | ||||
| from base64 import urlsafe_b64encode | ||||
| from json import dumps | ||||
| from secrets import token_bytes | ||||
|  | ||||
| from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives import hashes, serialization | ||||
| from cryptography.hazmat.primitives.asymmetric import ec | ||||
| from cryptography.hazmat.primitives.ciphers.aead import AESGCM | ||||
| from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash | ||||
| from django.http import HttpResponse | ||||
| from jwcrypto.common import base64url_decode, base64url_encode | ||||
|  | ||||
| from authentik.enterprise.providers.apple_psso.models import AppleDevice | ||||
|  | ||||
|  | ||||
| def length_prefixed(data: bytes) -> bytes: | ||||
|     length = len(data) | ||||
|     return length.to_bytes(4, "big") + data | ||||
|  | ||||
|  | ||||
| def build_apu(public_key: ec.EllipticCurvePublicKey): | ||||
|     # X9.63 representation: 0x04 || X || Y | ||||
|     public_numbers = public_key.public_numbers() | ||||
|  | ||||
|     x_bytes = public_numbers.x.to_bytes(32, "big") | ||||
|     y_bytes = public_numbers.y.to_bytes(32, "big") | ||||
|  | ||||
|     x963 = bytes([0x04]) + x_bytes + y_bytes | ||||
|  | ||||
|     result = length_prefixed(b"APPLE") + length_prefixed(x963) | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def encrypt_token_with_a256_gcm(body: dict, device_encryption_key: str, apv: bytes) -> str: | ||||
|     ephemeral_key = ec.generate_private_key(curve=ec.SECP256R1()) | ||||
|     device_public_key = serialization.load_pem_public_key( | ||||
|         device_encryption_key.encode(), backend=default_backend() | ||||
|     ) | ||||
|  | ||||
|     shared_secret_z = ephemeral_key.exchange(ec.ECDH(), device_public_key) | ||||
|  | ||||
|     apu = build_apu(ephemeral_key.public_key()) | ||||
|  | ||||
|     jwe_header = { | ||||
|         "enc": "A256GCM", | ||||
|         "kid": "ephemeralKey", | ||||
|         "epk": { | ||||
|             "x": base64url_encode( | ||||
|                 ephemeral_key.public_key().public_numbers().x.to_bytes(32, "big") | ||||
|             ), | ||||
|             "y": base64url_encode( | ||||
|                 ephemeral_key.public_key().public_numbers().y.to_bytes(32, "big") | ||||
|             ), | ||||
|             "kty": "EC", | ||||
|             "crv": "P-256", | ||||
|         }, | ||||
|         "typ": "platformsso-login-response+jwt", | ||||
|         "alg": "ECDH-ES", | ||||
|         "apu": base64url_encode(apu), | ||||
|         "apv": base64url_encode(apv), | ||||
|     } | ||||
|  | ||||
|     party_u_info = length_prefixed(apu) | ||||
|     party_v_info = length_prefixed(apv) | ||||
|     supp_pub_info = (256).to_bytes(4, "big") | ||||
|  | ||||
|     other_info = length_prefixed(b"A256GCM") + party_u_info + party_v_info + supp_pub_info | ||||
|  | ||||
|     ckdf = ConcatKDFHash( | ||||
|         algorithm=hashes.SHA256(), | ||||
|         length=32, | ||||
|         otherinfo=other_info, | ||||
|     ) | ||||
|  | ||||
|     derived_key = ckdf.derive(shared_secret_z) | ||||
|  | ||||
|     nonce = token_bytes(12) | ||||
|  | ||||
|     header_json = dumps(jwe_header, separators=(",", ":")).encode() | ||||
|     aad = urlsafe_b64encode(header_json).rstrip(b"=") | ||||
|  | ||||
|     aesgcm = AESGCM(derived_key) | ||||
|     ciphertext = aesgcm.encrypt(nonce, dumps(body).encode(), aad) | ||||
|  | ||||
|     ciphertext_body = ciphertext[:-16] | ||||
|     tag = ciphertext[-16:] | ||||
|  | ||||
|     # base64url encoding | ||||
|     protected_b64 = urlsafe_b64encode(header_json).rstrip(b"=") | ||||
|     iv_b64 = urlsafe_b64encode(nonce).rstrip(b"=") | ||||
|     ciphertext_b64 = urlsafe_b64encode(ciphertext_body).rstrip(b"=") | ||||
|     tag_b64 = urlsafe_b64encode(tag).rstrip(b"=") | ||||
|  | ||||
|     jwe_compact = b".".join( | ||||
|         [ | ||||
|             protected_b64, | ||||
|             b"", | ||||
|             iv_b64, | ||||
|             ciphertext_b64, | ||||
|             tag_b64, | ||||
|         ] | ||||
|     ) | ||||
|     return jwe_compact.decode() | ||||
|  | ||||
|  | ||||
| class JWEResponse(HttpResponse): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         data: dict, | ||||
|         device: AppleDevice, | ||||
|         apv: str, | ||||
|     ): | ||||
|         super().__init__( | ||||
|             content=encrypt_token_with_a256_gcm(data, device.encryption_key, base64url_decode(apv)), | ||||
|             content_type="application/platformsso-login-response+jwt", | ||||
|         ) | ||||
| @ -1,36 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-28 00:12 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     initial = True | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_oauth2", "0028_migrate_session"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="ApplePlatformSSOProvider", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "oauth2provider_ptr", | ||||
|                     models.OneToOneField( | ||||
|                         auto_created=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         parent_link=True, | ||||
|                         primary_key=True, | ||||
|                         serialize=False, | ||||
|                         to="authentik_providers_oauth2.oauth2provider", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|             options={ | ||||
|                 "abstract": False, | ||||
|             }, | ||||
|             bases=("authentik_providers_oauth2.oauth2provider",), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,94 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-28 15:50 | ||||
|  | ||||
| import django.db.models.deletion | ||||
| import uuid | ||||
| from django.conf import settings | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_providers_apple_psso", "0001_initial"), | ||||
|         migrations.swappable_dependency(settings.AUTH_USER_MODEL), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.CreateModel( | ||||
|             name="AppleDevice", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "endpoint_uuid", | ||||
|                     models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False), | ||||
|                 ), | ||||
|                 ("signing_key", models.TextField()), | ||||
|                 ("encryption_key", models.TextField()), | ||||
|                 ("key_exchange_key", models.TextField()), | ||||
|                 ("sign_key_id", models.TextField()), | ||||
|                 ("enc_key_id", models.TextField()), | ||||
|                 ("creation_time", models.DateTimeField(auto_now_add=True)), | ||||
|                 ( | ||||
|                     "provider", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_providers_apple_psso.appleplatformssoprovider", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="AppleDeviceUser", | ||||
|             fields=[ | ||||
|                 ("uuid", models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)), | ||||
|                 ("signing_key", models.TextField()), | ||||
|                 ("encryption_key", models.TextField()), | ||||
|                 ("sign_key_id", models.TextField()), | ||||
|                 ("enc_key_id", models.TextField()), | ||||
|                 ( | ||||
|                     "device", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         to="authentik_providers_apple_psso.appledevice", | ||||
|                     ), | ||||
|                 ), | ||||
|                 ( | ||||
|                     "user", | ||||
|                     models.ForeignKey( | ||||
|                         on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="appledevice", | ||||
|             name="users", | ||||
|             field=models.ManyToManyField( | ||||
|                 through="authentik_providers_apple_psso.AppleDeviceUser", | ||||
|                 to=settings.AUTH_USER_MODEL, | ||||
|             ), | ||||
|         ), | ||||
|         migrations.CreateModel( | ||||
|             name="AppleNonce", | ||||
|             fields=[ | ||||
|                 ( | ||||
|                     "id", | ||||
|                     models.AutoField( | ||||
|                         auto_created=True, primary_key=True, serialize=False, verbose_name="ID" | ||||
|                     ), | ||||
|                 ), | ||||
|                 ("expires", models.DateTimeField(default=None, null=True)), | ||||
|                 ("expiring", models.BooleanField(default=True)), | ||||
|                 ("nonce", models.TextField()), | ||||
|             ], | ||||
|             options={ | ||||
|                 "abstract": False, | ||||
|                 "indexes": [ | ||||
|                     models.Index(fields=["expires"], name="authentik_p_expires_47d534_idx"), | ||||
|                     models.Index(fields=["expiring"], name="authentik_p_expirin_87253e_idx"), | ||||
|                     models.Index( | ||||
|                         fields=["expiring", "expires"], name="authentik_p_expirin_20a7c9_idx" | ||||
|                     ), | ||||
|                 ], | ||||
|             }, | ||||
|         ), | ||||
|     ] | ||||
| @ -1,34 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-28 22:18 | ||||
|  | ||||
| from django.db import migrations | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ( | ||||
|             "authentik_providers_apple_psso", | ||||
|             "0002_appledevice_appledeviceuser_appledevice_users_and_more", | ||||
|         ), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="appledeviceuser", | ||||
|             old_name="sign_key_id", | ||||
|             new_name="enclave_key_id", | ||||
|         ), | ||||
|         migrations.RenameField( | ||||
|             model_name="appledeviceuser", | ||||
|             old_name="signing_key", | ||||
|             new_name="secure_enclave_key", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="appledeviceuser", | ||||
|             name="enc_key_id", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="appledeviceuser", | ||||
|             name="encryption_key", | ||||
|         ), | ||||
|     ] | ||||
| @ -1,85 +0,0 @@ | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.db import models | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from rest_framework.serializers import Serializer | ||||
|  | ||||
| from authentik.core.models import ExpiringModel, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     ClientTypes, | ||||
|     IssuerMode, | ||||
|     OAuth2Provider, | ||||
|     RedirectURI, | ||||
|     RedirectURIMatchingMode, | ||||
|     ScopeMapping, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class ApplePlatformSSOProvider(OAuth2Provider): | ||||
|     """Integrate with Apple Platform SSO""" | ||||
|  | ||||
|     def set_oauth_defaults(self): | ||||
|         """Ensure all OAuth2-related settings are correct""" | ||||
|         self.issuer_mode = IssuerMode.PER_PROVIDER | ||||
|         self.client_type = ClientTypes.PUBLIC | ||||
|         self.signing_key = CertificateKeyPair.objects.get(name="authentik Self-signed Certificate") | ||||
|         self.include_claims_in_id_token = True | ||||
|         scopes = ScopeMapping.objects.filter( | ||||
|             managed__in=[ | ||||
|                 "goauthentik.io/providers/oauth2/scope-openid", | ||||
|                 "goauthentik.io/providers/oauth2/scope-profile", | ||||
|                 "goauthentik.io/providers/oauth2/scope-email", | ||||
|                 "goauthentik.io/providers/oauth2/scope-offline_access", | ||||
|                 "goauthentik.io/providers/oauth2/scope-authentik_api", | ||||
|             ] | ||||
|         ) | ||||
|         self.property_mappings.add(*list(scopes)) | ||||
|         self.redirect_uris = [ | ||||
|             RedirectURI(RedirectURIMatchingMode.STRICT, "io.goauthentik.endpoint:/oauth2redirect"), | ||||
|         ] | ||||
|  | ||||
|     @property | ||||
|     def component(self) -> str: | ||||
|         return "ak-provider-apple-psso-form" | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|         from authentik.enterprise.providers.apple_psso.api.providers import ( | ||||
|             ApplePlatformSSOProviderSerializer, | ||||
|         ) | ||||
|  | ||||
|         return ApplePlatformSSOProviderSerializer | ||||
|  | ||||
|     class Meta: | ||||
|         verbose_name = _("Apple Platform SSO Provider") | ||||
|         verbose_name_plural = _("Apple Platform SSO Providers") | ||||
|  | ||||
|  | ||||
| class AppleDevice(models.Model): | ||||
|  | ||||
|     endpoint_uuid = models.UUIDField(default=uuid4, primary_key=True) | ||||
|  | ||||
|     signing_key = models.TextField() | ||||
|     encryption_key = models.TextField() | ||||
|     key_exchange_key = models.TextField() | ||||
|     sign_key_id = models.TextField() | ||||
|     enc_key_id = models.TextField() | ||||
|     creation_time = models.DateTimeField(auto_now_add=True) | ||||
|     provider = models.ForeignKey(ApplePlatformSSOProvider, on_delete=models.CASCADE) | ||||
|     users = models.ManyToManyField(User, through="AppleDeviceUser") | ||||
|  | ||||
|  | ||||
| class AppleDeviceUser(models.Model): | ||||
|  | ||||
|     uuid = models.UUIDField(default=uuid4, primary_key=True) | ||||
|  | ||||
|     device = models.ForeignKey(AppleDevice, on_delete=models.CASCADE) | ||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||
|  | ||||
|     secure_enclave_key = models.TextField() | ||||
|     enclave_key_id = models.TextField() | ||||
|  | ||||
|  | ||||
| class AppleNonce(ExpiringModel): | ||||
|     nonce = models.TextField() | ||||
| @ -1,15 +0,0 @@ | ||||
| from django.urls import path | ||||
|  | ||||
| from authentik.enterprise.providers.apple_psso.views.nonce import NonceView | ||||
| from authentik.enterprise.providers.apple_psso.views.register import ( | ||||
|     RegisterDeviceView, | ||||
|     RegisterUserView, | ||||
| ) | ||||
| from authentik.enterprise.providers.apple_psso.views.token import TokenView | ||||
|  | ||||
| urlpatterns = [ | ||||
|     path("token/", TokenView.as_view(), name="token"), | ||||
|     path("nonce/", NonceView.as_view(), name="nonce"), | ||||
|     path("register/device/", RegisterDeviceView.as_view(), name="register-device"), | ||||
|     path("register/user/", RegisterUserView.as_view(), name="register-user"), | ||||
| ] | ||||
| @ -1,7 +0,0 @@ | ||||
| from django.urls import path | ||||
|  | ||||
| from authentik.enterprise.providers.apple_psso.views.site_association import AppleAppSiteAssociation | ||||
|  | ||||
| urlpatterns = [ | ||||
|     path(".well-known/apple-app-site-association", AppleAppSiteAssociation.as_view(), name="asa"), | ||||
| ] | ||||
| @ -1,25 +0,0 @@ | ||||
| from base64 import b64encode | ||||
| from datetime import timedelta | ||||
| from secrets import token_bytes | ||||
|  | ||||
| from django.http import HttpRequest, JsonResponse | ||||
| from django.utils.decorators import method_decorator | ||||
| from django.utils.timezone import now | ||||
| from django.views import View | ||||
| from django.views.decorators.csrf import csrf_exempt | ||||
|  | ||||
| from authentik.enterprise.providers.apple_psso.models import AppleNonce | ||||
|  | ||||
|  | ||||
| @method_decorator(csrf_exempt, name="dispatch") | ||||
| class NonceView(View): | ||||
|  | ||||
|     def post(self, request: HttpRequest, *args, **kwargs): | ||||
|         nonce = AppleNonce.objects.create( | ||||
|             nonce=b64encode(token_bytes(32)).decode(), expires=now() + timedelta(minutes=5) | ||||
|         ) | ||||
|         return JsonResponse( | ||||
|             { | ||||
|                 "Nonce": nonce.nonce, | ||||
|             } | ||||
|         ) | ||||
| @ -1,92 +0,0 @@ | ||||
| from django.shortcuts import get_object_or_404 | ||||
| from rest_framework.authentication import BaseAuthentication | ||||
| from rest_framework.fields import CharField | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.response import Response | ||||
| from rest_framework.views import APIView | ||||
|  | ||||
| from authentik.api.authentication import TokenAuthentication | ||||
| from authentik.core.api.utils import PassiveSerializer | ||||
| from authentik.core.models import User | ||||
| from authentik.enterprise.providers.apple_psso.models import ( | ||||
|     AppleDevice, | ||||
|     AppleDeviceUser, | ||||
|     ApplePlatformSSOProvider, | ||||
| ) | ||||
| from authentik.lib.generators import generate_key | ||||
|  | ||||
|  | ||||
| class DeviceRegisterAuth(BaseAuthentication): | ||||
|     def authenticate(self, request): | ||||
|         # very temporary, lol | ||||
|         return (User(), None) | ||||
|  | ||||
|  | ||||
| class RegisterDeviceView(APIView): | ||||
|  | ||||
|     class DeviceRegistration(PassiveSerializer): | ||||
|  | ||||
|         device_uuid = CharField() | ||||
|         client_id = CharField() | ||||
|         device_signing_key = CharField() | ||||
|         device_encryption_key = CharField() | ||||
|         sign_key_id = CharField() | ||||
|         enc_key_id = CharField() | ||||
|  | ||||
|     permission_classes = [] | ||||
|     pagination_class = None | ||||
|     filter_backends = [] | ||||
|     serializer_class = DeviceRegistration | ||||
|     authentication_classes = [DeviceRegisterAuth, TokenAuthentication] | ||||
|  | ||||
|     def post(self, request: Request) -> Response: | ||||
|         data = self.DeviceRegistration(data=request.data) | ||||
|         data.is_valid(raise_exception=True) | ||||
|         provider = get_object_or_404( | ||||
|             ApplePlatformSSOProvider, client_id=data.validated_data["client_id"] | ||||
|         ) | ||||
|         AppleDevice.objects.update_or_create( | ||||
|             endpoint_uuid=data.validated_data["device_uuid"], | ||||
|             defaults={ | ||||
|                 "signing_key": data.validated_data["device_signing_key"], | ||||
|                 "encryption_key": data.validated_data["device_encryption_key"], | ||||
|                 "sign_key_id": data.validated_data["sign_key_id"], | ||||
|                 "enc_key_id": data.validated_data["enc_key_id"], | ||||
|                 "key_exchange_key": generate_key(), | ||||
|                 "provider": provider, | ||||
|             }, | ||||
|         ) | ||||
|         return Response() | ||||
|  | ||||
|  | ||||
| class RegisterUserView(APIView): | ||||
|  | ||||
|     class UserRegistration(PassiveSerializer): | ||||
|  | ||||
|         device_uuid = CharField() | ||||
|         user_secure_enclave_key = CharField() | ||||
|         enclave_key_id = CharField() | ||||
|  | ||||
|     permission_classes = [] | ||||
|     pagination_class = None | ||||
|     filter_backends = [] | ||||
|     serializer_class = UserRegistration | ||||
|     authentication_classes = [TokenAuthentication] | ||||
|  | ||||
|     def post(self, request: Request) -> Response: | ||||
|         data = self.UserRegistration(data=request.data) | ||||
|         data.is_valid(raise_exception=True) | ||||
|         device = get_object_or_404(AppleDevice, endpoint_uuid=data.validated_data["device_uuid"]) | ||||
|         AppleDeviceUser.objects.update_or_create( | ||||
|             device=device, | ||||
|             user=request.user, | ||||
|             defaults={ | ||||
|                 "secure_enclave_key": data.validated_data["user_secure_enclave_key"], | ||||
|                 "enclave_key_id": data.validated_data["enclave_key_id"], | ||||
|             }, | ||||
|         ) | ||||
|         return Response( | ||||
|             { | ||||
|                 "username": request.user.username, | ||||
|             } | ||||
|         ) | ||||
| @ -1,16 +0,0 @@ | ||||
| from django.http import HttpRequest, HttpResponse, JsonResponse | ||||
| from django.views import View | ||||
|  | ||||
|  | ||||
| class AppleAppSiteAssociation(View): | ||||
|     def get(self, request: HttpRequest) -> HttpResponse: | ||||
|         return JsonResponse( | ||||
|             { | ||||
|                 "authsrv": { | ||||
|                     "apps": [ | ||||
|                         "232G855Y8N.io.goauthentik.endpoint", | ||||
|                         "232G855Y8N.io.goauthentik.endpoint.psso", | ||||
|                     ] | ||||
|                 } | ||||
|             } | ||||
|         ) | ||||
| @ -1,140 +0,0 @@ | ||||
| from datetime import timedelta | ||||
|  | ||||
| from django.http import Http404, HttpRequest, HttpResponse | ||||
| from django.utils.decorators import method_decorator | ||||
| from django.utils.timezone import now | ||||
| from django.views import View | ||||
| from django.views.decorators.csrf import csrf_exempt | ||||
| from jwt import PyJWT, decode | ||||
| from rest_framework.exceptions import ValidationError | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, Session, User | ||||
| from authentik.core.sessions import SessionStore | ||||
| from authentik.enterprise.providers.apple_psso.http import JWEResponse | ||||
| from authentik.enterprise.providers.apple_psso.models import ( | ||||
|     AppleDevice, | ||||
|     AppleDeviceUser, | ||||
|     AppleNonce, | ||||
|     ApplePlatformSSOProvider, | ||||
| ) | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.events.signals import SESSION_LOGIN_EVENT | ||||
| from authentik.providers.oauth2.constants import TOKEN_TYPE | ||||
| from authentik.providers.oauth2.id_token import IDToken | ||||
| from authentik.providers.oauth2.models import RefreshToken | ||||
| from authentik.root.middleware import SessionMiddleware | ||||
|  | ||||
| LOGGER = get_logger() | ||||
|  | ||||
|  | ||||
| @method_decorator(csrf_exempt, name="dispatch") | ||||
| class TokenView(View): | ||||
|  | ||||
|     device: AppleDevice | ||||
|     provider: ApplePlatformSSOProvider | ||||
|  | ||||
|     def post(self, request: HttpRequest) -> HttpResponse: | ||||
|         version = request.POST.get("platform_sso_version") | ||||
|         assertion = request.POST.get("assertion", request.POST.get("request")) | ||||
|         if not assertion: | ||||
|             return HttpResponse(status=400) | ||||
|  | ||||
|         decode_unvalidated = PyJWT().decode_complete(assertion, options={"verify_signature": False}) | ||||
|         LOGGER.debug(decode_unvalidated["header"]) | ||||
|         expected_kid = decode_unvalidated["header"]["kid"] | ||||
|  | ||||
|         self.device = AppleDevice.objects.filter(sign_key_id=expected_kid).first() | ||||
|         if not self.device: | ||||
|             raise Http404 | ||||
|         self.provider = self.device.provider | ||||
|  | ||||
|         # Properly decode the JWT with the key from the device | ||||
|         decoded = decode( | ||||
|             assertion, self.device.signing_key, algorithms=["ES256"], options={"verify_aud": False} | ||||
|         ) | ||||
|         LOGGER.debug(decoded) | ||||
|  | ||||
|         LOGGER.debug("got device", device=self.device) | ||||
|  | ||||
|         # Check that the nonce hasn't been used before | ||||
|         nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first() | ||||
|         if not nonce: | ||||
|             return HttpResponse(status=400) | ||||
|         nonce.delete() | ||||
|  | ||||
|         handler_func = ( | ||||
|             f"handle_v{version}_{decode_unvalidated["header"]["typ"]}".replace("-", "_") | ||||
|             .replace("+", "_") | ||||
|             .replace(".", "_") | ||||
|         ) | ||||
|         handler = getattr(self, handler_func, None) | ||||
|         if not handler: | ||||
|             LOGGER.debug("Handler not found", handler=handler_func) | ||||
|             return HttpResponse(status=400) | ||||
|         LOGGER.debug("sending to handler", handler=handler_func) | ||||
|         return handler(decoded) | ||||
|  | ||||
|     def validate_device_user_response(self, assertion: str) -> tuple[AppleDeviceUser, dict] | None: | ||||
|         """Decode an embedded assertion and validate it by looking up the matching device user""" | ||||
|         decode_unvalidated = PyJWT().decode_complete(assertion, options={"verify_signature": False}) | ||||
|         expected_kid = decode_unvalidated["header"]["kid"] | ||||
|  | ||||
|         device_user = AppleDeviceUser.objects.filter( | ||||
|             device=self.device, enclave_key_id=expected_kid | ||||
|         ).first() | ||||
|         if not device_user: | ||||
|             return None | ||||
|         return device_user, decode( | ||||
|             assertion, | ||||
|             device_user.secure_enclave_key, | ||||
|             audience="apple-platform-sso", | ||||
|             algorithms=["ES256"], | ||||
|         ) | ||||
|  | ||||
|     def create_auth_session(self, user: User): | ||||
|         event = Event.new(EventAction.LOGIN).from_http(self.request, user=user) | ||||
|         store = SessionStore() | ||||
|         store[SESSION_LOGIN_EVENT] = event | ||||
|         store.save() | ||||
|         session = Session.objects.filter(session_key=store.session_key).first() | ||||
|         AuthenticatedSession.objects.create(session=session, user=user) | ||||
|         session = SessionMiddleware.encode_session(store.session_key, user) | ||||
|         return session | ||||
|  | ||||
|     def handle_v1_0_platformsso_login_request_jwt(self, decoded: dict): | ||||
|         user = None | ||||
|         if decoded["grant_type"] == "urn:ietf:params:oauth:grant-type:jwt-bearer": | ||||
|             # Decode and validate inner assertion | ||||
|             user, inner = self.validate_device_user_response(decoded["assertion"]) | ||||
|             if inner["nonce"] != decoded["nonce"]: | ||||
|                 LOGGER.warning("Mis-matched nonce to outer assertion") | ||||
|                 raise ValidationError("Invalid request") | ||||
|  | ||||
|         refresh_token = RefreshToken( | ||||
|             user=user.user, | ||||
|             scope=decoded["scope"], | ||||
|             expires=now() + timedelta(hours=8), | ||||
|             provider=self.provider, | ||||
|             auth_time=now(), | ||||
|             session=None, | ||||
|         ) | ||||
|         id_token = IDToken.new( | ||||
|             self.provider, | ||||
|             refresh_token, | ||||
|             self.request, | ||||
|         ) | ||||
|         id_token.nonce = decoded["nonce"] | ||||
|         refresh_token.id_token = id_token | ||||
|         refresh_token.save() | ||||
|         return JWEResponse( | ||||
|             { | ||||
|                 "refresh_token": refresh_token.token, | ||||
|                 "refresh_token_expires_in": int((refresh_token.expires - now()).total_seconds()), | ||||
|                 "id_token": refresh_token.id_token.to_jwt(self.provider), | ||||
|                 "token_type": TOKEN_TYPE, | ||||
|                 "session_key": self.create_auth_session(user.user), | ||||
|             }, | ||||
|             device=self.device, | ||||
|             apv=decoded["jwe_crypto"]["apv"], | ||||
|         ) | ||||
| @ -1,8 +1,10 @@ | ||||
| from hashlib import sha256 | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import post_delete, post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http.request import HttpRequest | ||||
| from guardian.shortcuts import assign_perm | ||||
|  | ||||
| from authentik.core.models import ( | ||||
| @ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: | ||||
|             instance.save() | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_): | ||||
|     """Session revoked trigger (user logged out)""" | ||||
|     if not request.session or not request.session.session_key or not user: | ||||
|         return | ||||
|     send_ssf_event( | ||||
|         EventTypes.CAEP_SESSION_REVOKED, | ||||
|         { | ||||
|             "initiating_entity": "user", | ||||
|         }, | ||||
|         sub_id={ | ||||
|             "format": "complex", | ||||
|             "session": { | ||||
|                 "format": "opaque", | ||||
|                 "id": sha256(request.session.session_key.encode("ascii")).hexdigest(), | ||||
|             }, | ||||
|             "user": { | ||||
|                 "format": "email", | ||||
|                 "email": user.email, | ||||
|             }, | ||||
|         }, | ||||
|         request=request, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): | ||||
|     """Session revoked trigger (users' session has been deleted) | ||||
|  | ||||
| @ -1,12 +0,0 @@ | ||||
| """Enterprise app config""" | ||||
|  | ||||
| from authentik.enterprise.apps import EnterpriseConfig | ||||
|  | ||||
|  | ||||
| class AuthentikEnterpriseSearchConfig(EnterpriseConfig): | ||||
|     """Enterprise app config""" | ||||
|  | ||||
|     name = "authentik.enterprise.search" | ||||
|     label = "authentik_search" | ||||
|     verbose_name = "authentik Enterprise.Search" | ||||
|     default = True | ||||
| @ -1,128 +0,0 @@ | ||||
| """DjangoQL search""" | ||||
|  | ||||
| from collections import OrderedDict, defaultdict | ||||
| from collections.abc import Generator | ||||
|  | ||||
| from django.db import connection | ||||
| from django.db.models import Model, Q | ||||
| from djangoql.compat import text_type | ||||
| from djangoql.schema import StrField | ||||
|  | ||||
|  | ||||
| class JSONSearchField(StrField): | ||||
|     """JSON field for DjangoQL""" | ||||
|  | ||||
|     model: Model | ||||
|  | ||||
|     def __init__(self, model=None, name=None, nullable=None, suggest_nested=True): | ||||
|         # Set this in the constructor to not clobber the type variable | ||||
|         self.type = "relation" | ||||
|         self.suggest_nested = suggest_nested | ||||
|         super().__init__(model, name, nullable) | ||||
|  | ||||
|     def get_lookup(self, path, operator, value): | ||||
|         search = "__".join(path) | ||||
|         op, invert = self.get_operator(operator) | ||||
|         q = Q(**{f"{search}{op}": self.get_lookup_value(value)}) | ||||
|         return ~q if invert else q | ||||
|  | ||||
|     def json_field_keys(self) -> Generator[tuple[str]]: | ||||
|         with connection.cursor() as cursor: | ||||
|             cursor.execute( | ||||
|                 f""" | ||||
|                 WITH RECURSIVE "{self.name}_keys" AS ( | ||||
|                     SELECT | ||||
|                         ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array, | ||||
|                         "{self.name}" -> jsonb_object_keys("{self.name}") AS value | ||||
|                     FROM {self.model._meta.db_table} | ||||
|                     WHERE "{self.name}" IS NOT NULL | ||||
|                         AND jsonb_typeof("{self.name}") = 'object' | ||||
|  | ||||
|                     UNION ALL | ||||
|  | ||||
|                     SELECT | ||||
|                         ck.key_path_array || jsonb_object_keys(ck.value), | ||||
|                         ck.value -> jsonb_object_keys(ck.value) AS value | ||||
|                     FROM "{self.name}_keys" ck | ||||
|                     WHERE jsonb_typeof(ck.value) = 'object' | ||||
|                 ), | ||||
|  | ||||
|                 unique_paths AS ( | ||||
|                     SELECT DISTINCT key_path_array | ||||
|                     FROM "{self.name}_keys" | ||||
|                 ) | ||||
|  | ||||
|                 SELECT key_path_array FROM unique_paths; | ||||
|             """  # nosec | ||||
|             ) | ||||
|             return (x[0] for x in cursor.fetchall()) | ||||
|  | ||||
|     def get_nested_options(self) -> OrderedDict: | ||||
|         """Get keys of all nested objects to show autocomplete""" | ||||
|         if not self.suggest_nested: | ||||
|             return OrderedDict() | ||||
|         base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" | ||||
|  | ||||
|         def recursive_function(parts: list[str], parent_parts: list[str] | None = None): | ||||
|             if not parent_parts: | ||||
|                 parent_parts = [] | ||||
|             path = parts.pop(0) | ||||
|             parent_parts.append(path) | ||||
|             relation_key = "_".join(parent_parts) | ||||
|             if len(parts) > 1: | ||||
|                 out_dict = { | ||||
|                     relation_key: { | ||||
|                         parts[0]: { | ||||
|                             "type": "relation", | ||||
|                             "relation": f"{relation_key}_{parts[0]}", | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 child_paths = recursive_function(parts.copy(), parent_parts.copy()) | ||||
|                 child_paths.update(out_dict) | ||||
|                 return child_paths | ||||
|             else: | ||||
|                 return {relation_key: {parts[0]: {}}} | ||||
|  | ||||
|         relation_structure = defaultdict(dict) | ||||
|  | ||||
|         for relations in self.json_field_keys(): | ||||
|             result = recursive_function([base_model_name] + relations) | ||||
|             for relation_key, value in result.items(): | ||||
|                 for sub_relation_key, sub_value in value.items(): | ||||
|                     if not relation_structure[relation_key].get(sub_relation_key, None): | ||||
|                         relation_structure[relation_key][sub_relation_key] = sub_value | ||||
|                     else: | ||||
|                         relation_structure[relation_key][sub_relation_key].update(sub_value) | ||||
|  | ||||
|         final_dict = defaultdict(dict) | ||||
|  | ||||
|         for key, value in relation_structure.items(): | ||||
|             for sub_key, sub_value in value.items(): | ||||
|                 if not sub_value: | ||||
|                     final_dict[key][sub_key] = { | ||||
|                         "type": "str", | ||||
|                         "nullable": True, | ||||
|                     } | ||||
|                 else: | ||||
|                     final_dict[key][sub_key] = sub_value | ||||
|         return OrderedDict(final_dict) | ||||
|  | ||||
|     def relation(self) -> str: | ||||
|         return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}" | ||||
|  | ||||
|  | ||||
| class ChoiceSearchField(StrField): | ||||
|     def __init__(self, model=None, name=None, nullable=None): | ||||
|         super().__init__(model, name, nullable, suggest_options=True) | ||||
|  | ||||
|     def get_options(self, search): | ||||
|         result = [] | ||||
|         choices = self._field_choices() | ||||
|         if choices: | ||||
|             search = search.lower() | ||||
|             for c in choices: | ||||
|                 choice = text_type(c[0]) | ||||
|                 if search in choice.lower(): | ||||
|                     result.append(choice) | ||||
|         return result | ||||
| @ -1,53 +0,0 @@ | ||||
| from rest_framework.response import Response | ||||
|  | ||||
| from authentik.api.pagination import Pagination | ||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch | ||||
|  | ||||
|  | ||||
| class AutocompletePagination(Pagination): | ||||
|  | ||||
|     def paginate_queryset(self, queryset, request, view=None): | ||||
|         self.view = view | ||||
|         return super().paginate_queryset(queryset, request, view) | ||||
|  | ||||
|     def get_autocomplete(self): | ||||
|         schema = QLSearch().get_schema(self.request, self.view) | ||||
|         introspections = {} | ||||
|         if hasattr(self.view, "get_ql_fields"): | ||||
|             from authentik.enterprise.search.schema import AKQLSchemaSerializer | ||||
|  | ||||
|             introspections = AKQLSchemaSerializer().serialize( | ||||
|                 schema(self.page.paginator.object_list.model) | ||||
|             ) | ||||
|         return introspections | ||||
|  | ||||
|     def get_paginated_response(self, data): | ||||
|         previous_page_number = 0 | ||||
|         if self.page.has_previous(): | ||||
|             previous_page_number = self.page.previous_page_number() | ||||
|         next_page_number = 0 | ||||
|         if self.page.has_next(): | ||||
|             next_page_number = self.page.next_page_number() | ||||
|         return Response( | ||||
|             { | ||||
|                 "pagination": { | ||||
|                     "next": next_page_number, | ||||
|                     "previous": previous_page_number, | ||||
|                     "count": self.page.paginator.count, | ||||
|                     "current": self.page.number, | ||||
|                     "total_pages": self.page.paginator.num_pages, | ||||
|                     "start_index": self.page.start_index(), | ||||
|                     "end_index": self.page.end_index(), | ||||
|                 }, | ||||
|                 "results": data, | ||||
|                 "autocomplete": self.get_autocomplete(), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     def get_paginated_response_schema(self, schema): | ||||
|         final_schema = super().get_paginated_response_schema(schema) | ||||
|         final_schema["properties"]["autocomplete"] = { | ||||
|             "$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}" | ||||
|         } | ||||
|         final_schema["required"].append("autocomplete") | ||||
|         return final_schema | ||||
| @ -1,78 +0,0 @@ | ||||
| """DjangoQL search""" | ||||
|  | ||||
| from django.apps import apps | ||||
| from django.db.models import QuerySet | ||||
| from djangoql.ast import Name | ||||
| from djangoql.exceptions import DjangoQLError | ||||
| from djangoql.queryset import apply_search | ||||
| from djangoql.schema import DjangoQLSchema | ||||
| from rest_framework.filters import SearchFilter | ||||
| from rest_framework.request import Request | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.enterprise.search.fields import JSONSearchField | ||||
|  | ||||
| LOGGER = get_logger() | ||||
| AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete" | ||||
| AUTOCOMPLETE_SCHEMA = { | ||||
|     "type": "object", | ||||
|     "additionalProperties": {}, | ||||
| } | ||||
|  | ||||
|  | ||||
| class BaseSchema(DjangoQLSchema): | ||||
|     """Base Schema which deals with JSON Fields""" | ||||
|  | ||||
|     def resolve_name(self, name: Name): | ||||
|         model = self.model_label(self.current_model) | ||||
|         root_field = name.parts[0] | ||||
|         field = self.models[model].get(root_field) | ||||
|         # If the query goes into a JSON field, return the root | ||||
|         # field as the JSON field will do the rest | ||||
|         if isinstance(field, JSONSearchField): | ||||
|             # This is a workaround; build_filter will remove the right-most | ||||
|             # entry in the path as that is intended to be the same as the field | ||||
|             # however for JSON that is not the case | ||||
|             if name.parts[-1] != root_field: | ||||
|                 name.parts.append(root_field) | ||||
|             return field | ||||
|         return super().resolve_name(name) | ||||
|  | ||||
|  | ||||
| class QLSearch(SearchFilter): | ||||
|     """rest_framework search filter which uses DjangoQL""" | ||||
|  | ||||
|     @property | ||||
|     def enabled(self): | ||||
|         return apps.get_app_config("authentik_enterprise").enabled() | ||||
|  | ||||
|     def get_search_terms(self, request) -> str: | ||||
|         """ | ||||
|         Search terms are set by a ?search=... query parameter, | ||||
|         and may be comma and/or whitespace delimited. | ||||
|         """ | ||||
|         params = request.query_params.get(self.search_param, "") | ||||
|         params = params.replace("\x00", "")  # strip null characters | ||||
|         return params | ||||
|  | ||||
|     def get_schema(self, request: Request, view) -> BaseSchema: | ||||
|         ql_fields = [] | ||||
|         if hasattr(view, "get_ql_fields"): | ||||
|             ql_fields = view.get_ql_fields() | ||||
|  | ||||
|         class InlineSchema(BaseSchema): | ||||
|             def get_fields(self, model): | ||||
|                 return ql_fields or [] | ||||
|  | ||||
|         return InlineSchema | ||||
|  | ||||
|     def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet: | ||||
|         search_query = self.get_search_terms(request) | ||||
|         schema = self.get_schema(request, view) | ||||
|         if len(search_query) == 0 or not self.enabled: | ||||
|             return super().filter_queryset(request, queryset, view) | ||||
|         try: | ||||
|             return apply_search(queryset, search_query, schema=schema) | ||||
|         except DjangoQLError as exc: | ||||
|             LOGGER.debug("Failed to parse search expression", exc=exc) | ||||
|             return super().filter_queryset(request, queryset, view) | ||||
| @ -1,29 +0,0 @@ | ||||
| from djangoql.serializers import DjangoQLSchemaSerializer | ||||
| from drf_spectacular.generators import SchemaGenerator | ||||
|  | ||||
| from authentik.api.schema import create_component | ||||
| from authentik.enterprise.search.fields import JSONSearchField | ||||
| from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA | ||||
|  | ||||
|  | ||||
| class AKQLSchemaSerializer(DjangoQLSchemaSerializer): | ||||
|     def serialize(self, schema): | ||||
|         serialization = super().serialize(schema) | ||||
|         for _, fields in schema.models.items(): | ||||
|             for _, field in fields.items(): | ||||
|                 if not isinstance(field, JSONSearchField): | ||||
|                     continue | ||||
|                 serialization["models"].update(field.get_nested_options()) | ||||
|         return serialization | ||||
|  | ||||
|     def serialize_field(self, field): | ||||
|         result = super().serialize_field(field) | ||||
|         if isinstance(field, JSONSearchField): | ||||
|             result["relation"] = field.relation() | ||||
|         return result | ||||
|  | ||||
|  | ||||
| def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs): | ||||
|     create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA) | ||||
|  | ||||
|     return result | ||||
| @ -1,17 +0,0 @@ | ||||
| SPECTACULAR_SETTINGS = { | ||||
|     "POSTPROCESSING_HOOKS": [ | ||||
|         "authentik.api.schema.postprocess_schema_responses", | ||||
|         "authentik.enterprise.search.schema.postprocess_schema_search_autocomplete", | ||||
|         "drf_spectacular.hooks.postprocess_schema_enums", | ||||
|     ], | ||||
| } | ||||
|  | ||||
| REST_FRAMEWORK = { | ||||
|     "DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination", | ||||
|     "DEFAULT_FILTER_BACKENDS": [ | ||||
|         "authentik.enterprise.search.ql.QLSearch", | ||||
|         "authentik.rbac.filters.ObjectFilter", | ||||
|         "django_filters.rest_framework.DjangoFilterBackend", | ||||
|         "rest_framework.filters.OrderingFilter", | ||||
|     ], | ||||
| } | ||||
| @ -1,78 +0,0 @@ | ||||
| from json import loads | ||||
| from unittest.mock import PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.tests.utils import create_test_admin_user | ||||
|  | ||||
|  | ||||
| @patch( | ||||
|     "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||
|     PropertyMock(return_value=True), | ||||
| ) | ||||
| class QLTest(APITestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         self.user = create_test_admin_user() | ||||
|         # ensure we have more than 1 user | ||||
|         create_test_admin_user() | ||||
|  | ||||
|     def test_search(self): | ||||
|         """Test simple search query""" | ||||
|         self.client.force_login(self.user) | ||||
|         query = f'username = "{self.user.username}"' | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": query})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
|  | ||||
|     def test_no_search(self): | ||||
|         """Ensure works with no search query""" | ||||
|         self.client.force_login(self.user) | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertNotEqual(content["pagination"]["count"], 1) | ||||
|  | ||||
|     def test_search_no_ql(self): | ||||
|         """Test simple search query (no QL)""" | ||||
|         self.client.force_login(self.user) | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": self.user.username})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertGreaterEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
|  | ||||
|     def test_search_json(self): | ||||
|         """Test search query with a JSON attribute""" | ||||
|         self.user.attributes = {"foo": {"bar": "baz"}} | ||||
|         self.user.save() | ||||
|         self.client.force_login(self.user) | ||||
|         query = 'attributes.foo.bar = "baz"' | ||||
|         res = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_api:user-list", | ||||
|             ) | ||||
|             + f"?{urlencode({"search": query})}" | ||||
|         ) | ||||
|         self.assertEqual(res.status_code, 200) | ||||
|         content = loads(res.content) | ||||
|         self.assertEqual(content["pagination"]["count"], 1) | ||||
|         self.assertEqual(content["results"][0]["username"], self.user.username) | ||||
| @ -15,11 +15,9 @@ CELERY_BEAT_SCHEDULE = { | ||||
| TENANT_APPS = [ | ||||
|     "authentik.enterprise.audit", | ||||
|     "authentik.enterprise.policies.unique_password", | ||||
|     "authentik.enterprise.providers.apple_psso", | ||||
|     "authentik.enterprise.providers.google_workspace", | ||||
|     "authentik.enterprise.providers.microsoft_entra", | ||||
|     "authentik.enterprise.providers.ssf", | ||||
|     "authentik.enterprise.search", | ||||
|     "authentik.enterprise.stages.authenticator_endpoint_gdtc", | ||||
|     "authentik.enterprise.stages.mtls", | ||||
|     "authentik.enterprise.stages.source", | ||||
|  | ||||
| @ -97,7 +97,6 @@ class SourceStageFinal(StageView): | ||||
|         token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) | ||||
|         self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) | ||||
|         plan = token.plan | ||||
|         plan.context.update(self.executor.plan.context) | ||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = token | ||||
|         response = plan.to_redirect(self.request, token.flow) | ||||
|         token.delete() | ||||
|  | ||||
| @ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase): | ||||
|         plan: FlowPlan = session[SESSION_KEY_PLAN] | ||||
|         plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) | ||||
|         plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token | ||||
|         plan.context["foo"] = "bar" | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         # Pretend we've just returned from the source | ||||
|         with self.assertFlowFinishes() as ff: | ||||
|             response = self.client.get( | ||||
|                 reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||
|             ) | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|             self.assertStageRedirects( | ||||
|                 response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||
|             ) | ||||
|         self.assertEqual(ff().context["foo"], "bar") | ||||
|         response = self.client.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertStageRedirects( | ||||
|             response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) | ||||
|         ) | ||||
|  | ||||
| @ -132,22 +132,6 @@ class EventViewSet(ModelViewSet): | ||||
|     ] | ||||
|     filterset_class = EventsFilter | ||||
|  | ||||
|     def get_ql_fields(self): | ||||
|         from djangoql.schema import DateTimeField, StrField | ||||
|  | ||||
|         from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField | ||||
|  | ||||
|         return [ | ||||
|             ChoiceSearchField(Event, "action"), | ||||
|             StrField(Event, "event_uuid"), | ||||
|             StrField(Event, "app", suggest_options=True), | ||||
|             StrField(Event, "client_ip"), | ||||
|             JSONSearchField(Event, "user", suggest_nested=False), | ||||
|             JSONSearchField(Event, "brand", suggest_nested=False), | ||||
|             JSONSearchField(Event, "context", suggest_nested=False), | ||||
|             DateTimeField(Event, "created", suggest_options=True), | ||||
|         ] | ||||
|  | ||||
|     @extend_schema( | ||||
|         methods=["GET"], | ||||
|         responses={200: EventTopPerUserSerializer(many=True)}, | ||||
|  | ||||
| @ -11,7 +11,7 @@ from authentik.events.models import NotificationRule | ||||
| class NotificationRuleSerializer(ModelSerializer): | ||||
|     """NotificationRule Serializer""" | ||||
|  | ||||
|     destination_group_obj = GroupSerializer(read_only=True, source="destination_group") | ||||
|     group_obj = GroupSerializer(read_only=True, source="group") | ||||
|  | ||||
|     class Meta: | ||||
|         model = NotificationRule | ||||
| @ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer): | ||||
|             "name", | ||||
|             "transports", | ||||
|             "severity", | ||||
|             "destination_group", | ||||
|             "destination_group_obj", | ||||
|             "destination_event_user", | ||||
|             "group", | ||||
|             "group_obj", | ||||
|         ] | ||||
|  | ||||
|  | ||||
| @ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet): | ||||
|  | ||||
|     queryset = NotificationRule.objects.all() | ||||
|     serializer_class = NotificationRuleSerializer | ||||
|     filterset_fields = ["name", "severity", "destination_group__name"] | ||||
|     filterset_fields = ["name", "severity", "group__name"] | ||||
|     ordering = ["name"] | ||||
|     search_fields = ["name", "destination_group__name"] | ||||
|     search_fields = ["name", "group__name"] | ||||
|  | ||||
| @ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor): | ||||
|         self.reader: Reader | None = None | ||||
|         self._last_mtime: float = 0.0 | ||||
|         self.logger = get_logger() | ||||
|         self.load() | ||||
|         self.open() | ||||
|  | ||||
|     def path(self) -> str | None: | ||||
|         """Get the path to the MMDB file to load""" | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def load(self): | ||||
|     def open(self): | ||||
|         """Get GeoIP Reader, if configured, otherwise none""" | ||||
|         path = self.path() | ||||
|         if path == "" or not path: | ||||
| @ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor): | ||||
|             diff = self._last_mtime < mtime | ||||
|             if diff > 0: | ||||
|                 self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) | ||||
|                 self.load() | ||||
|                 self.open() | ||||
|         except OSError as exc: | ||||
|             self.logger.warning("Failed to check MMDB age", exc=exc) | ||||
|  | ||||
|  | ||||
| @ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.events.models import Event, EventAction, Notification | ||||
| from authentik.events.utils import model_to_dict | ||||
| from authentik.lib.sentry import should_ignore_exception | ||||
| from authentik.lib.sentry import before_send | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.stages.authenticator_static.models import StaticToken | ||||
|  | ||||
| @ -173,7 +173,7 @@ class AuditMiddleware: | ||||
|                 message=exception_to_string(exception), | ||||
|             ) | ||||
|             thread.run() | ||||
|         elif not should_ignore_exception(exception): | ||||
|         elif before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||
|             thread = EventNewThread( | ||||
|                 EventAction.SYSTEM_EXCEPTION, | ||||
|                 request, | ||||
|  | ||||
| @ -1,26 +0,0 @@ | ||||
| # Generated by Django 5.1.11 on 2025-06-16 23:21 | ||||
|  | ||||
| from django.db import migrations, models | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RenameField( | ||||
|             model_name="notificationrule", | ||||
|             old_name="group", | ||||
|             new_name="destination_group", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="notificationrule", | ||||
|             name="destination_event_user", | ||||
|             field=models.BooleanField( | ||||
|                 default=False, | ||||
|                 help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @ -1,12 +1,10 @@ | ||||
| """authentik events models""" | ||||
|  | ||||
| from collections.abc import Generator | ||||
| from datetime import timedelta | ||||
| from difflib import get_close_matches | ||||
| from functools import lru_cache | ||||
| from inspect import currentframe | ||||
| from smtplib import SMTPException | ||||
| from typing import Any | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.apps import apps | ||||
| @ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel): | ||||
|             brand: Brand = request.brand | ||||
|             self.brand = sanitize_dict(model_to_dict(brand)) | ||||
|         if hasattr(request, "user"): | ||||
|             self.user = get_user(request.user) | ||||
|             original_user = None | ||||
|             if hasattr(request, "session"): | ||||
|                 original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None) | ||||
|             self.user = get_user(request.user, original_user) | ||||
|         if user: | ||||
|             self.user = get_user(user) | ||||
|         # Check if we're currently impersonating, and add that user | ||||
|         if hasattr(request, "session"): | ||||
|             from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||
|  | ||||
|             # Check if we're currently impersonating, and add that user | ||||
|             if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: | ||||
|                 self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) | ||||
|                 self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) | ||||
|             # Special case for events that happen during a flow, the user might not be authenticated | ||||
|             # yet but is a pending user instead | ||||
|             if SESSION_KEY_PLAN in request.session: | ||||
|                 from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | ||||
|  | ||||
|                 plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||
|                 pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None) | ||||
|                 # Only save `authenticated_as` if there's a different pending user in the flow | ||||
|                 # than the user that is authenticated | ||||
|                 if pending_user and ( | ||||
|                     (pending_user.pk and pending_user.pk != self.user.get("pk")) | ||||
|                     or (not pending_user.pk) | ||||
|                 ): | ||||
|                     orig_user = self.user.copy() | ||||
|  | ||||
|                     self.user = {"authenticated_as": orig_user, **get_user(pending_user)} | ||||
|         # User 255.255.255.255 as fallback if IP cannot be determined | ||||
|         self.client_ip = ClientIPMiddleware.get_client_ip(request) | ||||
|         # Enrich event data | ||||
| @ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | ||||
|         default=NotificationSeverity.NOTICE, | ||||
|         help_text=_("Controls which severity level the created notifications will have."), | ||||
|     ) | ||||
|     destination_group = models.ForeignKey( | ||||
|     group = models.ForeignKey( | ||||
|         Group, | ||||
|         help_text=_( | ||||
|             "Define which group of users this notification should be sent and shown to. " | ||||
| @ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel): | ||||
|         blank=True, | ||||
|         on_delete=models.SET_NULL, | ||||
|     ) | ||||
|     destination_event_user = models.BooleanField( | ||||
|         default=False, | ||||
|         help_text=_( | ||||
|             "When enabled, notification will be sent to user the user that triggered the event." | ||||
|             "When destination_group is configured, notification is sent to both." | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     def destination_users(self, event: Event) -> Generator[User, Any]: | ||||
|         if self.destination_event_user and event.user.get("pk"): | ||||
|             yield User(pk=event.user.get("pk")) | ||||
|         if self.destination_group: | ||||
|             yield from self.destination_group.users.all() | ||||
|  | ||||
|     @property | ||||
|     def serializer(self) -> type[Serializer]: | ||||
|  | ||||
| @ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str): | ||||
|     if not result.passing: | ||||
|         return | ||||
|  | ||||
|     if not trigger.group: | ||||
|         LOGGER.debug("e(trigger): trigger has no group", trigger=trigger) | ||||
|         return | ||||
|  | ||||
|     LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) | ||||
|     # Create the notification objects | ||||
|     for transport in trigger.transports.all(): | ||||
|         for user in trigger.destination_users(event): | ||||
|         for user in trigger.group.users.all(): | ||||
|             LOGGER.debug("created notification") | ||||
|             notification_transport.apply_async( | ||||
|                 args=[ | ||||
|  | ||||
| @ -2,9 +2,7 @@ | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.events.context_processors.base import get_context_processors | ||||
| from authentik.events.context_processors.geoip import GeoIPContextProcessor | ||||
| from authentik.events.models import Event, EventAction | ||||
|  | ||||
|  | ||||
| class TestGeoIP(TestCase): | ||||
| @ -15,7 +13,8 @@ class TestGeoIP(TestCase): | ||||
|  | ||||
|     def test_simple(self): | ||||
|         """Test simple city wrapper""" | ||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         # IPs from | ||||
|         # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         self.assertEqual( | ||||
|             self.reader.city_dict("2.125.160.216"), | ||||
|             { | ||||
| @ -26,12 +25,3 @@ class TestGeoIP(TestCase): | ||||
|                 "long": -1.25, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_special_chars(self): | ||||
|         """Test city name with special characters""" | ||||
|         # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json | ||||
|         event = Event.new(EventAction.LOGIN) | ||||
|         event.client_ip = "89.160.20.112" | ||||
|         for processor in get_context_processors(): | ||||
|             processor.enrich_event(event) | ||||
|         event.save() | ||||
|  | ||||
| @ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter | ||||
| from guardian.shortcuts import get_anonymous_user | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.core.models import Group | ||||
| from authentik.events.models import Event | ||||
| from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan | ||||
| from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN | ||||
| from authentik.flows.views.executor import QS_QUERY | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.policies.dummy.models import DummyPolicy | ||||
|  | ||||
| @ -118,92 +116,3 @@ class TestEvents(TestCase): | ||||
|                 "pk": brand.pk.hex, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = user | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user_anon(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = create_test_user() | ||||
|         anon = get_anonymous_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = anon | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "authenticated_as": { | ||||
|                     "pk": anon.pk, | ||||
|                     "is_anonymous": True, | ||||
|                     "username": "AnonymousUser", | ||||
|                     "email": "", | ||||
|                 }, | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_from_http_flow_pending_user_fake(self): | ||||
|         """Test request from flow request with a pending user""" | ||||
|         user = User( | ||||
|             username=generate_id(), | ||||
|             email=generate_id(), | ||||
|         ) | ||||
|         anon = get_anonymous_user() | ||||
|  | ||||
|         session = self.client.session | ||||
|         plan = FlowPlan(generate_id()) | ||||
|         plan.context[PLAN_CONTEXT_PENDING_USER] = user | ||||
|         session[SESSION_KEY_PLAN] = plan | ||||
|         session.save() | ||||
|  | ||||
|         request = self.factory.get("/") | ||||
|         request.session = session | ||||
|         request.user = anon | ||||
|  | ||||
|         event = Event.new("unittest").from_http(request) | ||||
|         self.assertEqual( | ||||
|             event.user, | ||||
|             { | ||||
|                 "authenticated_as": { | ||||
|                     "pk": anon.pk, | ||||
|                     "is_anonymous": True, | ||||
|                     "username": "AnonymousUser", | ||||
|                     "email": "", | ||||
|                 }, | ||||
|                 "email": user.email, | ||||
|                 "pk": user.pk, | ||||
|                 "username": user.username, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @ -6,7 +6,6 @@ from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.events.models import ( | ||||
|     Event, | ||||
|     EventAction, | ||||
| @ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase): | ||||
|     def test_trigger_empty(self): | ||||
|         """Test trigger without any policies attached""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|  | ||||
| @ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase): | ||||
|     def test_trigger_single(self): | ||||
|         """Test simple transport triggering""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase): | ||||
|             Event.new(EventAction.CUSTOM_PREFIX).save() | ||||
|         self.assertEqual(execute_mock.call_count, 1) | ||||
|  | ||||
|     def test_trigger_event_user(self): | ||||
|         """Test trigger with event user""" | ||||
|         user = create_test_user() | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|         ) | ||||
|         PolicyBinding.objects.create(target=trigger, policy=matcher, order=0) | ||||
|  | ||||
|         execute_mock = MagicMock() | ||||
|         with patch("authentik.events.models.NotificationTransport.send", execute_mock): | ||||
|             Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save() | ||||
|         self.assertEqual(execute_mock.call_count, 1) | ||||
|         notification: Notification = execute_mock.call_args[0][0] | ||||
|         self.assertEqual(notification.user, user) | ||||
|  | ||||
|     def test_trigger_no_group(self): | ||||
|         """Test trigger without group""" | ||||
|         trigger = NotificationRule.objects.create(name=generate_id()) | ||||
| @ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase): | ||||
|         """Test Policy error which would cause recursion""" | ||||
|         transport = NotificationTransport.objects.create(name=generate_id()) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase): | ||||
|  | ||||
|         transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         trigger.save() | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
| @ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase): | ||||
|             name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL | ||||
|         ) | ||||
|         NotificationRule.objects.filter(name__startswith="default").delete() | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) | ||||
|         trigger = NotificationRule.objects.create(name=generate_id(), group=self.group) | ||||
|         trigger.transports.add(transport) | ||||
|         matcher = EventMatcherPolicy.objects.create( | ||||
|             name="matcher", action=EventAction.CUSTOM_PREFIX | ||||
|  | ||||
| @ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]: | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_user(user: User | AnonymousUser) -> dict[str, Any]: | ||||
|     """Convert user object to dictionary""" | ||||
| def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]: | ||||
|     """Convert user object to dictionary, optionally including the original user""" | ||||
|     if isinstance(user, AnonymousUser): | ||||
|         try: | ||||
|             user = get_anonymous_user() | ||||
| @ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]: | ||||
|     } | ||||
|     if user.username == settings.ANONYMOUS_USER_NAME: | ||||
|         user_data["is_anonymous"] = True | ||||
|     if original_user: | ||||
|         original_data = get_user(original_user) | ||||
|         original_data["on_behalf_of"] = user_data | ||||
|         return original_data | ||||
|     return user_data | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch | ||||
| from urllib.parse import urlencode | ||||
|  | ||||
| from django.http import HttpRequest, HttpResponse | ||||
| from django.test import override_settings | ||||
| from django.test.client import RequestFactory | ||||
| from django.urls import reverse | ||||
| from rest_framework.exceptions import ParseError | ||||
|  | ||||
| from authentik.core.models import Group, User | ||||
| from authentik.core.tests.utils import create_test_flow, create_test_user | ||||
| @ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase): | ||||
|             self.assertStageResponse(response, flow, component="ak-stage-identification") | ||||
|             response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) | ||||
|             self.assertStageResponse(response, flow, component="ak-stage-access-denied") | ||||
|  | ||||
|     @patch( | ||||
|         "authentik.flows.views.executor.to_stage_response", | ||||
|         TO_STAGE_RESPONSE_MOCK, | ||||
|     ) | ||||
|     def test_invalid_json(self): | ||||
|         """Test invalid JSON body""" | ||||
|         flow = create_test_flow() | ||||
|         FlowStageBinding.objects.create( | ||||
|             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 | ||||
|         ) | ||||
|         url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) | ||||
|  | ||||
|         with override_settings(TEST=False, DEBUG=False): | ||||
|             self.client.logout() | ||||
|             response = self.client.post(url, data="{", content_type="application/json") | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|         with self.assertRaises(ParseError): | ||||
|             self.client.logout() | ||||
|             response = self.client.post(url, data="{", content_type="application/json") | ||||
|             self.assertEqual(response.status_code, 200) | ||||
|  | ||||
| @ -55,7 +55,7 @@ from authentik.flows.planner import ( | ||||
|     FlowPlanner, | ||||
| ) | ||||
| from authentik.flows.stage import AccessDeniedStage, StageView | ||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | ||||
| from authentik.lib.sentry import SentryIgnoredException | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
| from authentik.lib.utils.reflection import all_subclasses, class_to_path | ||||
| from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs | ||||
| @ -234,13 +234,12 @@ class FlowExecutorView(APIView): | ||||
|         """Handle exception in stage execution""" | ||||
|         if settings.DEBUG or settings.TEST: | ||||
|             raise exc | ||||
|         capture_exception(exc) | ||||
|         self._logger.warning(exc) | ||||
|         if not should_ignore_exception(exc): | ||||
|             capture_exception(exc) | ||||
|             Event.new( | ||||
|                 action=EventAction.SYSTEM_EXCEPTION, | ||||
|                 message=exception_to_string(exc), | ||||
|             ).from_http(self.request) | ||||
|         Event.new( | ||||
|             action=EventAction.SYSTEM_EXCEPTION, | ||||
|             message=exception_to_string(exc), | ||||
|         ).from_http(self.request) | ||||
|         challenge = FlowErrorChallenge(self.request, exc) | ||||
|         challenge.is_valid(raise_exception=True) | ||||
|         return to_stage_response(self.request, HttpChallengeResponse(challenge)) | ||||
|  | ||||
| @ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted | ||||
| from docker.errors import DockerException | ||||
| from h11 import LocalProtocolError | ||||
| from ldap3.core.exceptions import LDAPException | ||||
| from psycopg.errors import Error | ||||
| from redis.exceptions import ConnectionError as RedisConnectionError | ||||
| from redis.exceptions import RedisError, ResponseError | ||||
| from rest_framework.exceptions import APIException | ||||
| @ -45,49 +44,6 @@ class SentryIgnoredException(Exception): | ||||
|     """Base Class for all errors that are suppressed, and not sent to sentry.""" | ||||
|  | ||||
|  | ||||
| ignored_classes = ( | ||||
|     # Inbuilt types | ||||
|     KeyboardInterrupt, | ||||
|     ConnectionResetError, | ||||
|     OSError, | ||||
|     PermissionError, | ||||
|     # Django Errors | ||||
|     Error, | ||||
|     ImproperlyConfigured, | ||||
|     DatabaseError, | ||||
|     OperationalError, | ||||
|     InternalError, | ||||
|     ProgrammingError, | ||||
|     SuspiciousOperation, | ||||
|     ValidationError, | ||||
|     # Redis errors | ||||
|     RedisConnectionError, | ||||
|     ConnectionInterrupted, | ||||
|     RedisError, | ||||
|     ResponseError, | ||||
|     # websocket errors | ||||
|     ChannelFull, | ||||
|     WebSocketException, | ||||
|     LocalProtocolError, | ||||
|     # rest_framework error | ||||
|     APIException, | ||||
|     # celery errors | ||||
|     WorkerLostError, | ||||
|     CeleryError, | ||||
|     SoftTimeLimitExceeded, | ||||
|     # custom baseclass | ||||
|     SentryIgnoredException, | ||||
|     # ldap errors | ||||
|     LDAPException, | ||||
|     # Docker errors | ||||
|     DockerException, | ||||
|     # End-user errors | ||||
|     Http404, | ||||
|     # AsyncIO | ||||
|     CancelledError, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class SentryTransport(HttpTransport): | ||||
|     """Custom sentry transport with custom user-agent""" | ||||
|  | ||||
| @ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float: | ||||
|     return float(CONFIG.get("error_reporting.sample_rate", 0.1)) | ||||
|  | ||||
|  | ||||
| def should_ignore_exception(exc: Exception) -> bool: | ||||
|     """Check if an exception should be dropped""" | ||||
|     return isinstance(exc, ignored_classes) | ||||
|  | ||||
|  | ||||
| def before_send(event: dict, hint: dict) -> dict | None: | ||||
|     """Check if error is database error, and ignore if so""" | ||||
|  | ||||
|     from psycopg.errors import Error | ||||
|  | ||||
|     ignored_classes = ( | ||||
|         # Inbuilt types | ||||
|         KeyboardInterrupt, | ||||
|         ConnectionResetError, | ||||
|         OSError, | ||||
|         PermissionError, | ||||
|         # Django Errors | ||||
|         Error, | ||||
|         ImproperlyConfigured, | ||||
|         DatabaseError, | ||||
|         OperationalError, | ||||
|         InternalError, | ||||
|         ProgrammingError, | ||||
|         SuspiciousOperation, | ||||
|         ValidationError, | ||||
|         # Redis errors | ||||
|         RedisConnectionError, | ||||
|         ConnectionInterrupted, | ||||
|         RedisError, | ||||
|         ResponseError, | ||||
|         # websocket errors | ||||
|         ChannelFull, | ||||
|         WebSocketException, | ||||
|         LocalProtocolError, | ||||
|         # rest_framework error | ||||
|         APIException, | ||||
|         # celery errors | ||||
|         WorkerLostError, | ||||
|         CeleryError, | ||||
|         SoftTimeLimitExceeded, | ||||
|         # custom baseclass | ||||
|         SentryIgnoredException, | ||||
|         # ldap errors | ||||
|         LDAPException, | ||||
|         # Docker errors | ||||
|         DockerException, | ||||
|         # End-user errors | ||||
|         Http404, | ||||
|         # AsyncIO | ||||
|         CancelledError, | ||||
|     ) | ||||
|     exc_value = None | ||||
|     if "exc_info" in hint: | ||||
|         _, exc_value, _ = hint["exc_info"] | ||||
|         if should_ignore_exception(exc_value): | ||||
|         if isinstance(exc_value, ignored_classes): | ||||
|             LOGGER.debug("dropping exception", exc=exc_value) | ||||
|             return None | ||||
|     if "logger" in event: | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
|  | ||||
| from django.test import TestCase | ||||
|  | ||||
| from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception | ||||
| from authentik.lib.sentry import SentryIgnoredException, before_send | ||||
|  | ||||
|  | ||||
| class TestSentry(TestCase): | ||||
| @ -10,8 +10,8 @@ class TestSentry(TestCase): | ||||
|  | ||||
|     def test_error_not_sent(self): | ||||
|         """Test SentryIgnoredError not sent""" | ||||
|         self.assertTrue(should_ignore_exception(SentryIgnoredException())) | ||||
|         self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)})) | ||||
|  | ||||
|     def test_error_sent(self): | ||||
|         """Test error sent""" | ||||
|         self.assertFalse(should_ignore_exception(ValueError())) | ||||
|         self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)})) | ||||
|  | ||||
| @ -1,13 +1,15 @@ | ||||
| """authentik outpost signals""" | ||||
|  | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.core.cache import cache | ||||
| from django.db.models import Model | ||||
| from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.brands.models import Brand | ||||
| from authentik.core.models import AuthenticatedSession, Provider | ||||
| from authentik.core.models import AuthenticatedSession, Provider, User | ||||
| from authentik.crypto.models import CertificateKeyPair | ||||
| from authentik.lib.utils.reflection import class_to_path | ||||
| from authentik.outposts.models import Outpost, OutpostServiceConnection | ||||
| @ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_): | ||||
|     outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def logout_revoke_direct(sender: type[User], request: HttpRequest, **_): | ||||
|     """Catch logout by direct logout and forward to providers""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     outpost_session_end.delay(request.session.session_key) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): | ||||
|     """Catch logout by expiring sessions being deleted""" | ||||
|  | ||||
| @ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException): | ||||
|  | ||||
|     error: str | ||||
|     description: str | ||||
|     cause: str | None = None | ||||
|  | ||||
|     def create_dict(self): | ||||
|         """Return error as dict for JSON Rendering""" | ||||
| @ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException): | ||||
|             **kwargs, | ||||
|         ) | ||||
|  | ||||
|     def with_cause(self, cause: str): | ||||
|         self.cause = cause | ||||
|         return self | ||||
|  | ||||
|  | ||||
| class RedirectUriError(OAuth2Error): | ||||
|     """The request fails due to a missing, invalid, or mismatching | ||||
|  | ||||
| @ -1,10 +1,23 @@ | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.db.models.signals import post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_): | ||||
|     """Revoke tokens upon user logout""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     AccessToken.objects.filter( | ||||
|         user=user, | ||||
|         session__session__session_key=request.session.session_key, | ||||
|     ).delete() | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): | ||||
|     """Revoke tokens upon user logout""" | ||||
|  | ||||
| @ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.lib.utils.time import timedelta_from_string | ||||
| from authentik.providers.oauth2.constants import TOKEN_TYPE | ||||
| from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE | ||||
| from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError | ||||
| from authentik.providers.oauth2.models import ( | ||||
|     AccessToken, | ||||
| @ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|         with self.assertRaises(AuthorizeError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.error, "unsupported_response_type") | ||||
|  | ||||
|     def test_invalid_client_id(self): | ||||
|         """Test invalid client ID""" | ||||
| @ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|         with self.assertRaises(AuthorizeError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.error, "request_not_supported") | ||||
|  | ||||
|     def test_invalid_redirect_uri(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
|     def test_invalid_redirect_uri_missing(self): | ||||
|         """test missing redirect URI""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|         with self.assertRaises(RedirectUriError) as cm: | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|         self.assertEqual(cm.exception.cause, "redirect_uri_missing") | ||||
|  | ||||
|     def test_invalid_redirect_uri(self): | ||||
|         """test invalid redirect URI""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||
|  | ||||
|     def test_blocked_redirect_uri(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
| @ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")], | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|         with self.assertRaises(RedirectUriError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme") | ||||
|  | ||||
|     def test_invalid_redirect_uri_empty(self): | ||||
|         """test missing/invalid redirect URI""" | ||||
| @ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase): | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             data={ | ||||
| @ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")], | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|         with self.assertRaises(RedirectUriError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||
|  | ||||
|     def test_redirect_uri_invalid_regex(self): | ||||
|         """test missing/invalid redirect URI (invalid regex)""" | ||||
| @ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase): | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")], | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|         with self.assertRaises(RedirectUriError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.cause, "redirect_uri_no_match") | ||||
|  | ||||
|     def test_empty_redirect_uri(self): | ||||
|         """test empty redirect URI (configure in provider)""" | ||||
|     def test_redirect_uri_regex(self): | ||||
|         """test valid redirect URI (regex)""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")], | ||||
|         ) | ||||
|         with self.assertRaises(RedirectUriError): | ||||
|             request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             data={ | ||||
|                 "response_type": "code", | ||||
|                 "client_id": "test", | ||||
|                 "redirect_uri": "http://localhost", | ||||
|                 "redirect_uri": "http://foo.bar.baz", | ||||
|             }, | ||||
|         ) | ||||
|         OAuthAuthorizationParams.from_request(request) | ||||
| @ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|             GrantTypes.IMPLICIT, | ||||
|         ) | ||||
|         # Implicit without openid scope | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|         with self.assertRaises(AuthorizeError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|         self.assertEqual( | ||||
|             OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError): | ||||
|         with self.assertRaises(AuthorizeError) as cm: | ||||
|             request = self.factory.get( | ||||
|                 "/", | ||||
|                 data={ | ||||
| @ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             ) | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.error, "unsupported_response_type") | ||||
|  | ||||
|     def test_full_code(self): | ||||
|         """Test full authorization""" | ||||
| @ -387,7 +393,8 @@ class TestAuthorize(OAuthTestCase): | ||||
|             self.assertEqual( | ||||
|                 response.url, | ||||
|                 ( | ||||
|                     f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}" | ||||
|                     f"http://localhost#access_token={token.token}" | ||||
|                     f"&id_token={provider.encode(token.id_token.to_dict())}" | ||||
|                     f"&token_type={TOKEN_TYPE}" | ||||
|                     f"&expires_in={int(expires)}&state={state}" | ||||
|                 ), | ||||
| @ -562,6 +569,7 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 "url": "http://localhost", | ||||
|                 "title": f"Redirecting to {app.name}...", | ||||
|                 "attrs": { | ||||
|                     "access_token": token.token, | ||||
|                     "id_token": provider.encode(token.id_token.to_dict()), | ||||
|                     "token_type": TOKEN_TYPE, | ||||
|                     "expires_in": "3600", | ||||
| @ -613,3 +621,54 @@ class TestAuthorize(OAuthTestCase): | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_openid_missing_invalid(self): | ||||
|         """test request requiring an OpenID scope to be set""" | ||||
|         OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|         ) | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             data={ | ||||
|                 "response_type": "id_token", | ||||
|                 "client_id": "test", | ||||
|                 "redirect_uri": "http://localhost", | ||||
|                 "scope": "", | ||||
|             }, | ||||
|         ) | ||||
|         with self.assertRaises(AuthorizeError) as cm: | ||||
|             OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertEqual(cm.exception.cause, "scope_openid_missing") | ||||
|  | ||||
|     @apply_blueprint("system/providers-oauth2.yaml") | ||||
|     def test_offline_access_invalid(self): | ||||
|         """test request for offline_access with invalid response type""" | ||||
|         provider = OAuth2Provider.objects.create( | ||||
|             name=generate_id(), | ||||
|             client_id="test", | ||||
|             authorization_flow=create_test_flow(), | ||||
|             redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")], | ||||
|         ) | ||||
|         provider.property_mappings.set( | ||||
|             ScopeMapping.objects.filter( | ||||
|                 managed__in=[ | ||||
|                     "goauthentik.io/providers/oauth2/scope-openid", | ||||
|                     "goauthentik.io/providers/oauth2/scope-offline_access", | ||||
|                 ] | ||||
|             ) | ||||
|         ) | ||||
|         request = self.factory.get( | ||||
|             "/", | ||||
|             data={ | ||||
|                 "response_type": "id_token", | ||||
|                 "client_id": "test", | ||||
|                 "redirect_uri": "http://localhost", | ||||
|                 "scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}", | ||||
|                 "nonce": generate_id(), | ||||
|             }, | ||||
|         ) | ||||
|         parsed = OAuthAuthorizationParams.from_request(request) | ||||
|         self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope) | ||||
|  | ||||
| @ -150,12 +150,12 @@ class OAuthAuthorizationParams: | ||||
|         self.check_redirect_uri() | ||||
|         self.check_grant() | ||||
|         self.check_scope(github_compat) | ||||
|         self.check_nonce() | ||||
|         self.check_code_challenge() | ||||
|         if self.request: | ||||
|             raise AuthorizeError( | ||||
|                 self.redirect_uri, "request_not_supported", self.grant_type, self.state | ||||
|             ) | ||||
|         self.check_nonce() | ||||
|         self.check_code_challenge() | ||||
|  | ||||
|     def check_grant(self): | ||||
|         """Check grant""" | ||||
| @ -190,7 +190,7 @@ class OAuthAuthorizationParams: | ||||
|         allowed_redirect_urls = self.provider.redirect_uris | ||||
|         if not self.redirect_uri: | ||||
|             LOGGER.warning("Missing redirect uri.") | ||||
|             raise RedirectUriError("", allowed_redirect_urls) | ||||
|             raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing") | ||||
|  | ||||
|         if len(allowed_redirect_urls) < 1: | ||||
|             LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) | ||||
| @ -219,10 +219,14 @@ class OAuthAuthorizationParams: | ||||
|                         provider=self.provider, | ||||
|                     ) | ||||
|         if not match_found: | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause( | ||||
|                 "redirect_uri_no_match" | ||||
|             ) | ||||
|         # Check against forbidden schemes | ||||
|         if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) | ||||
|             raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause( | ||||
|                 "redirect_uri_forbidden_scheme" | ||||
|             ) | ||||
|  | ||||
|     def check_scope(self, github_compat=False): | ||||
|         """Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" | ||||
| @ -251,7 +255,9 @@ class OAuthAuthorizationParams: | ||||
|             or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] | ||||
|         ): | ||||
|             LOGGER.warning("Missing 'openid' scope.") | ||||
|             raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state) | ||||
|             raise AuthorizeError( | ||||
|                 self.redirect_uri, "invalid_scope", self.grant_type, self.state | ||||
|             ).with_cause("scope_openid_missing") | ||||
|         if SCOPE_OFFLINE_ACCESS in self.scope: | ||||
|             # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess | ||||
|             # Don't explicitly request consent with offline_access, as the spec allows for | ||||
| @ -286,7 +292,9 @@ class OAuthAuthorizationParams: | ||||
|             return | ||||
|         if not self.nonce: | ||||
|             LOGGER.warning("Missing nonce for OpenID Request") | ||||
|             raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state) | ||||
|             raise AuthorizeError( | ||||
|                 self.redirect_uri, "invalid_request", self.grant_type, self.state | ||||
|             ).with_cause("none_missing") | ||||
|  | ||||
|     def check_code_challenge(self): | ||||
|         """PKCE validation of the transformation method.""" | ||||
| @ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView): | ||||
|                 self.request, github_compat=self.github_compat | ||||
|             ) | ||||
|         except AuthorizeError as error: | ||||
|             LOGGER.warning(error.description, redirect_uri=error.redirect_uri) | ||||
|             LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause) | ||||
|             raise RequestValidationError(error.get_response(self.request)) from None | ||||
|         except OAuth2Error as error: | ||||
|             LOGGER.warning(error.description) | ||||
|             LOGGER.warning(error.description, cause=error.cause) | ||||
|             raise RequestValidationError( | ||||
|                 bad_request_message(self.request, error.description, title=error.error) | ||||
|             ) from None | ||||
| @ -630,6 +638,7 @@ class OAuthFulfillmentStage(StageView): | ||||
|         if self.params.response_type in [ | ||||
|             ResponseTypes.ID_TOKEN_TOKEN, | ||||
|             ResponseTypes.CODE_ID_TOKEN_TOKEN, | ||||
|             ResponseTypes.ID_TOKEN, | ||||
|             ResponseTypes.CODE_TOKEN, | ||||
|         ]: | ||||
|             query_fragment["access_token"] = token.token | ||||
|  | ||||
| @ -555,8 +555,6 @@ class TokenView(View): | ||||
|  | ||||
|     provider: OAuth2Provider | None = None | ||||
|     params: TokenParams | None = None | ||||
|     params_class = TokenParams | ||||
|     provider_class = OAuth2Provider | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
|         response = super().dispatch(request, *args, **kwargs) | ||||
| @ -576,14 +574,12 @@ class TokenView(View): | ||||
|                 op="authentik.providers.oauth2.post.parse", | ||||
|             ): | ||||
|                 client_id, client_secret = extract_client_auth(request) | ||||
|                 self.provider = self.provider_class.objects.filter(client_id=client_id).first() | ||||
|                 self.provider = OAuth2Provider.objects.filter(client_id=client_id).first() | ||||
|                 if not self.provider: | ||||
|                     LOGGER.warning("OAuth2Provider does not exist", client_id=client_id) | ||||
|                     raise TokenError("invalid_client") | ||||
|                 CTX_AUTH_VIA.set("oauth_client_secret") | ||||
|                 self.params = self.params_class.parse( | ||||
|                     request, self.provider, client_id, client_secret | ||||
|                 ) | ||||
|                 self.params = TokenParams.parse(request, self.provider, client_id, client_secret) | ||||
|  | ||||
|             with start_span( | ||||
|                 op="authentik.providers.oauth2.post.response", | ||||
|  | ||||
| @ -66,10 +66,7 @@ class RACClientConsumer(AsyncWebsocketConsumer): | ||||
|     def init_outpost_connection(self): | ||||
|         """Initialize guac connection settings""" | ||||
|         self.token = ( | ||||
|             ConnectionToken.filter_not_expired( | ||||
|                 token=self.scope["url_route"]["kwargs"]["token"], | ||||
|                 session__session__session_key=self.scope["session"].session_key, | ||||
|             ) | ||||
|             ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"]) | ||||
|             .select_related("endpoint", "provider", "session", "session__user") | ||||
|             .first() | ||||
|         ) | ||||
|  | ||||
| @ -2,11 +2,13 @@ | ||||
|  | ||||
| from asgiref.sync import async_to_sync | ||||
| from channels.layers import get_channel_layer | ||||
| from django.contrib.auth.signals import user_logged_out | ||||
| from django.core.cache import cache | ||||
| from django.db.models.signals import post_delete, post_save, pre_delete | ||||
| from django.dispatch import receiver | ||||
| from django.http import HttpRequest | ||||
|  | ||||
| from authentik.core.models import AuthenticatedSession | ||||
| from authentik.core.models import AuthenticatedSession, User | ||||
| from authentik.providers.rac.api.endpoints import user_endpoint_cache_key | ||||
| from authentik.providers.rac.consumer_client import ( | ||||
|     RAC_CLIENT_GROUP_SESSION, | ||||
| @ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import ( | ||||
| from authentik.providers.rac.models import ConnectionToken, Endpoint | ||||
|  | ||||
|  | ||||
| @receiver(user_logged_out) | ||||
| def user_logged_out_session(sender, request: HttpRequest, user: User, **_): | ||||
|     """Disconnect any open RAC connections""" | ||||
|     if not request.session or not request.session.session_key: | ||||
|         return | ||||
|     layer = get_channel_layer() | ||||
|     async_to_sync(layer.group_send)( | ||||
|         RAC_CLIENT_GROUP_SESSION | ||||
|         % { | ||||
|             "session": request.session.session_key, | ||||
|         }, | ||||
|         {"type": "event.disconnect", "reason": "session_logout"}, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @receiver(pre_delete, sender=AuthenticatedSession) | ||||
| def user_session_deleted(sender, instance: AuthenticatedSession, **_): | ||||
|     layer = get_channel_layer() | ||||
|  | ||||
| @ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             response.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -87,22 +87,3 @@ class TestRACViews(APITestCase): | ||||
|         ) | ||||
|         body = loads(flow_response.content) | ||||
|         self.assertEqual(body["component"], "ak-stage-access-denied") | ||||
|  | ||||
|     def test_different_session(self): | ||||
|         """Test request""" | ||||
|         self.client.force_login(self.user) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_providers_rac:start", | ||||
|                 kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)}, | ||||
|             ) | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 302) | ||||
|         flow_response = self.client.get( | ||||
|             reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) | ||||
|         ) | ||||
|         body = loads(flow_response.content) | ||||
|         next_url = body["to"] | ||||
|         self.client.logout() | ||||
|         final_response = self.client.get(next_url) | ||||
|         self.assertEqual(final_response.url, reverse("authentik_core:if-user")) | ||||
|  | ||||
| @ -68,10 +68,7 @@ class RACInterface(InterfaceView): | ||||
|  | ||||
|     def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: | ||||
|         # Early sanity check to ensure token still exists | ||||
|         token = ConnectionToken.filter_not_expired( | ||||
|             token=self.kwargs["token"], | ||||
|             session__session__session_key=request.session.session_key, | ||||
|         ).first() | ||||
|         token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first() | ||||
|         if not token: | ||||
|             return redirect("authentik_core:if-user") | ||||
|         self.token = token | ||||
|  | ||||
| @ -5,6 +5,7 @@ from itertools import batched | ||||
| from django.db import transaction | ||||
| from pydantic import ValidationError | ||||
| from pydanticscim.group import GroupMember | ||||
| from pydanticscim.responses import PatchOp | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.lib.sync.mapper import PropertyMappingManager | ||||
| @ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient | ||||
| from authentik.providers.scim.clients.exceptions import ( | ||||
|     SCIMRequestException, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import ( | ||||
|     SCIM_GROUP_SCHEMA, | ||||
|     PatchOp, | ||||
|     PatchOperation, | ||||
|     PatchRequest, | ||||
| ) | ||||
| from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.providers.scim.models import ( | ||||
|     SCIMMapping, | ||||
|  | ||||
| @ -1,7 +1,5 @@ | ||||
| """Custom SCIM schemas""" | ||||
|  | ||||
| from enum import Enum | ||||
|  | ||||
| from pydantic import Field | ||||
| from pydanticscim.group import Group as BaseGroup | ||||
| from pydanticscim.responses import PatchOperation as BasePatchOperation | ||||
| @ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class PatchOp(str, Enum): | ||||
|  | ||||
|     replace = "replace" | ||||
|     remove = "remove" | ||||
|     add = "add" | ||||
|  | ||||
|     @classmethod | ||||
|     def _missing_(cls, value): | ||||
|         value = value.lower() | ||||
|         for member in cls: | ||||
|             if member.lower() == value: | ||||
|                 return member | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class PatchRequest(BasePatchRequest): | ||||
|     """PatchRequest which correctly sets schemas""" | ||||
|  | ||||
| @ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest): | ||||
| class PatchOperation(BasePatchOperation): | ||||
|     """PatchOperation with optional path""" | ||||
|  | ||||
|     op: PatchOp | ||||
|     path: str | None | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -38,7 +38,6 @@ class TestAPIPerms(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
| @ -74,7 +73,6 @@ class TestAPIPerms(APITestCase): | ||||
|         self.assertJSONEqual( | ||||
|             res.content.decode(), | ||||
|             { | ||||
|                 "autocomplete": {}, | ||||
|                 "pagination": { | ||||
|                     "next": 0, | ||||
|                     "previous": 0, | ||||
|  | ||||
| @ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/ | ||||
|  | ||||
| import django | ||||
| from channels.routing import ProtocolTypeRouter, URLRouter | ||||
| from defusedxml import defuse_stdlib | ||||
| from django.core.asgi import get_asgi_application | ||||
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | ||||
|  | ||||
| from authentik.root.setup import setup | ||||
|  | ||||
| # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py | ||||
|  | ||||
| setup() | ||||
| defuse_stdlib() | ||||
| django.setup() | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -27,7 +27,7 @@ from structlog.stdlib import get_logger | ||||
| from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp | ||||
|  | ||||
| from authentik import get_full_version | ||||
| from authentik.lib.sentry import should_ignore_exception | ||||
| from authentik.lib.sentry import before_send | ||||
| from authentik.lib.utils.errors import exception_to_string | ||||
|  | ||||
| # set the default Django settings module for the 'celery' program. | ||||
| @ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar | ||||
|  | ||||
|     LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) | ||||
|     CTX_TASK_ID.set(...) | ||||
|     if not should_ignore_exception(exception): | ||||
|     if before_send({}, {"exc_info": (None, exception, None)}) is not None: | ||||
|         Event.new( | ||||
|             EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id | ||||
|         ).save() | ||||
|  | ||||
| @ -1,49 +1,13 @@ | ||||
| """authentik database backend""" | ||||
|  | ||||
| from django.core.checks import Warning | ||||
| from django.db.backends.base.validation import BaseDatabaseValidation | ||||
| from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
|  | ||||
| class DatabaseValidation(BaseDatabaseValidation): | ||||
|  | ||||
|     def check(self, **kwargs): | ||||
|         return self._check_encoding() | ||||
|  | ||||
|     def _check_encoding(self): | ||||
|         """Throw a warning when the server_encoding is not UTF-8 or | ||||
|         server_encoding and client_encoding are mismatched""" | ||||
|         messages = [] | ||||
|         with self.connection.cursor() as cursor: | ||||
|             cursor.execute("SHOW server_encoding;") | ||||
|             server_encoding = cursor.fetchone()[0] | ||||
|             cursor.execute("SHOW client_encoding;") | ||||
|             client_encoding = cursor.fetchone()[0] | ||||
|             if server_encoding != client_encoding: | ||||
|                 messages.append( | ||||
|                     Warning( | ||||
|                         "PostgreSQL Server and Client encoding are mismatched: Server: " | ||||
|                         f"{server_encoding}, Client: {client_encoding}", | ||||
|                         id="ak.db.W001", | ||||
|                     ) | ||||
|                 ) | ||||
|             if server_encoding != "UTF8": | ||||
|                 messages.append( | ||||
|                     Warning( | ||||
|                         f"PostgreSQL Server encoding is not UTF8: {server_encoding}", | ||||
|                         id="ak.db.W002", | ||||
|                     ) | ||||
|                 ) | ||||
|         return messages | ||||
|  | ||||
|  | ||||
| class DatabaseWrapper(BaseDatabaseWrapper): | ||||
|     """database backend which supports rotating credentials""" | ||||
|  | ||||
|     validation_class = DatabaseValidation | ||||
|  | ||||
|     def get_connection_params(self): | ||||
|         """Refresh DB credentials before getting connection params""" | ||||
|         conn_params = super().get_connection_params() | ||||
|  | ||||
| @ -61,22 +61,6 @@ class SessionMiddleware(UpstreamSessionMiddleware): | ||||
|             pass | ||||
|         return session_key | ||||
|  | ||||
|     @staticmethod | ||||
|     def encode_session(session_key: str, user: User): | ||||
|         payload = { | ||||
|             "sid": session_key, | ||||
|             "iss": "authentik", | ||||
|             "sub": "anonymous", | ||||
|             "authenticated": user.is_authenticated, | ||||
|             "acr": ACR_AUTHENTIK_SESSION, | ||||
|         } | ||||
|         if user.is_authenticated: | ||||
|             payload["sub"] = user.uid | ||||
|         value = encode(payload=payload, key=SIGNING_HASH) | ||||
|         if settings.TEST: | ||||
|             value = session_key | ||||
|         return value | ||||
|  | ||||
|     def process_request(self, request: HttpRequest): | ||||
|         raw_session = request.COOKIES.get(settings.SESSION_COOKIE_NAME) | ||||
|         session_key = SessionMiddleware.decode_session_key(raw_session) | ||||
| @ -133,9 +117,21 @@ class SessionMiddleware(UpstreamSessionMiddleware): | ||||
|                             "request completed. The user may have logged " | ||||
|                             "out in a concurrent request, for example." | ||||
|                         ) from None | ||||
|                     payload = { | ||||
|                         "sid": request.session.session_key, | ||||
|                         "iss": "authentik", | ||||
|                         "sub": "anonymous", | ||||
|                         "authenticated": request.user.is_authenticated, | ||||
|                         "acr": ACR_AUTHENTIK_SESSION, | ||||
|                     } | ||||
|                     if request.user.is_authenticated: | ||||
|                         payload["sub"] = request.user.uid | ||||
|                     value = encode(payload=payload, key=SIGNING_HASH) | ||||
|                     if settings.TEST: | ||||
|                         value = request.session.session_key | ||||
|                     response.set_cookie( | ||||
|                         settings.SESSION_COOKIE_NAME, | ||||
|                         SessionMiddleware.encode_session(request.session.session_key, request.user), | ||||
|                         value, | ||||
|                         max_age=max_age, | ||||
|                         expires=expires, | ||||
|                         domain=settings.SESSION_COOKIE_DOMAIN, | ||||
|  | ||||
| @ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [ | ||||
|     "MIDDLEWARE", | ||||
|     "AUTHENTICATION_BACKENDS", | ||||
|     "CELERY", | ||||
|     "SPECTACULAR_SETTINGS", | ||||
|     "REST_FRAMEWORK", | ||||
| ] | ||||
|  | ||||
| SILENCED_SYSTEM_CHECKS = [ | ||||
| @ -470,8 +468,6 @@ def _update_settings(app_path: str): | ||||
|         TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) | ||||
|         MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) | ||||
|         AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) | ||||
|         SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {})) | ||||
|         REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {})) | ||||
|         CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) | ||||
|         for _attr in dir(settings_module): | ||||
|             if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: | ||||
|  | ||||
| @ -1,26 +0,0 @@ | ||||
| import os | ||||
| import warnings | ||||
|  | ||||
| from cryptography.hazmat.backends.openssl.backend import backend | ||||
| from defusedxml import defuse_stdlib | ||||
|  | ||||
| from authentik.lib.config import CONFIG | ||||
|  | ||||
|  | ||||
| def setup(): | ||||
|     warnings.filterwarnings("ignore", "SelectableGroups dict interface") | ||||
|     warnings.filterwarnings( | ||||
|         "ignore", | ||||
|         "defusedxml.lxml is no longer supported and will be removed in a future release.", | ||||
|     ) | ||||
|     warnings.filterwarnings( | ||||
|         "ignore", | ||||
|         "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.", | ||||
|     ) | ||||
|  | ||||
|     defuse_stdlib() | ||||
|  | ||||
|     if CONFIG.get_bool("compliance.fips.enabled", False): | ||||
|         backend._enable_fips() | ||||
|  | ||||
|     os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings") | ||||
| @ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType | ||||
| from django.test.runner import DiscoverRunner | ||||
| from structlog.stdlib import get_logger | ||||
|  | ||||
| from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR | ||||
| from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR | ||||
| from authentik.lib.config import CONFIG | ||||
| from authentik.lib.sentry import sentry_init | ||||
| from authentik.root.signals import post_startup, pre_startup, startup | ||||
| @ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover | ||||
|         for key, value in test_config.items(): | ||||
|             CONFIG.set(key, value) | ||||
|  | ||||
|         ASN_CONTEXT_PROCESSOR.load() | ||||
|         GEOIP_CONTEXT_PROCESSOR.load() | ||||
|  | ||||
|         sentry_init() | ||||
|         self.logger.debug("Test environment configured") | ||||
|  | ||||
|  | ||||
| @ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str): | ||||
|             return | ||||
|         # Delete all sync tasks from the cache | ||||
|         DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() | ||||
|  | ||||
|         # The order of these operations needs to be preserved as each depends on the previous one(s) | ||||
|         # 1. User and group sync can happen simultaneously | ||||
|         # 2. Membership sync needs to run afterwards | ||||
|         # 3. Finally, user and group deletions can happen simultaneously | ||||
|         user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( | ||||
|             source, GroupLDAPSynchronizer | ||||
|         task = chain( | ||||
|             # User and group sync can happen at once, they have no dependencies on each other | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, UserLDAPSynchronizer) | ||||
|                 + ldap_sync_paginator(source, GroupLDAPSynchronizer), | ||||
|             ), | ||||
|             # Membership sync needs to run afterwards | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, MembershipLDAPSynchronizer), | ||||
|             ), | ||||
|             # Finally, deletions. What we'd really like to do here is something like | ||||
|             # ``` | ||||
|             # user_identifiers = <ldap query> | ||||
|             # User.objects.exclude( | ||||
|             #     usersourceconnection__identifier__in=user_uniqueness_identifiers, | ||||
|             # ).delete() | ||||
|             # ``` | ||||
|             # This runs into performance issues in large installations. So instead we spread the | ||||
|             # work out into three steps: | ||||
|             # 1. Get every object from the LDAP source. | ||||
|             # 2. Mark every object as "safe" in the database. This is quick, but any error could | ||||
|             #    mean deleting users which should not be deleted, so we do it immediately, in | ||||
|             #    large chunks, and only queue the deletion step afterwards. | ||||
|             # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in | ||||
|             #    small chunks. | ||||
|             group( | ||||
|                 ldap_sync_paginator(source, UserLDAPForwardDeletion) | ||||
|                 + ldap_sync_paginator(source, GroupLDAPForwardDeletion), | ||||
|             ), | ||||
|         ) | ||||
|         membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) | ||||
|         user_group_deletion = ldap_sync_paginator( | ||||
|             source, UserLDAPForwardDeletion | ||||
|         ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion) | ||||
|  | ||||
|         # Celery is buggy with empty groups, so we are careful only to add non-empty groups. | ||||
|         # See https://github.com/celery/celery/issues/9772 | ||||
|         task_groups = [] | ||||
|         if user_group_sync: | ||||
|             task_groups.append(group(user_group_sync)) | ||||
|         if membership_sync: | ||||
|             task_groups.append(group(membership_sync)) | ||||
|         if user_group_deletion: | ||||
|             task_groups.append(group(user_group_deletion)) | ||||
|  | ||||
|         all_tasks = chain(task_groups) | ||||
|         all_tasks() | ||||
|         task() | ||||
|  | ||||
|  | ||||
| def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: | ||||
|  | ||||
| @ -1,277 +0,0 @@ | ||||
| """Test SCIM Group""" | ||||
|  | ||||
| from json import dumps | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from django.urls import reverse | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from authentik.core.models import Group | ||||
| from authentik.core.tests.utils import create_test_user | ||||
| from authentik.events.models import Event, EventAction | ||||
| from authentik.lib.generators import generate_id | ||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema | ||||
| from authentik.sources.scim.models import ( | ||||
|     SCIMSource, | ||||
|     SCIMSourceGroup, | ||||
| ) | ||||
| from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE | ||||
|  | ||||
|  | ||||
| class TestSCIMGroups(APITestCase): | ||||
|     """Test SCIM Group view""" | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id()) | ||||
|  | ||||
|     def test_group_list(self): | ||||
|         """Test full group list""" | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_group_list_single(self): | ||||
|         """Test full group list (single group)""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         user = create_test_user() | ||||
|         group.users.add(user) | ||||
|         SCIMSourceGroup.objects.create( | ||||
|             source=self.source, | ||||
|             group=group, | ||||
|             id=str(uuid4()), | ||||
|         ) | ||||
|         response = self.client.get( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "group_id": str(group.pk), | ||||
|                 }, | ||||
|             ), | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         SCIMGroupSchema.model_validate_json(response.content, strict=True) | ||||
|  | ||||
|     def test_group_create(self): | ||||
|         """Test group create""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_members(self): | ||||
|         """Test group create""" | ||||
|         user = create_test_user() | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "displayName": generate_id(), | ||||
|                     "externalId": ext_id, | ||||
|                     "members": [{"value": str(user.uuid)}], | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_members_empty(self): | ||||
|         """Test group create""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 201) | ||||
|         self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists()) | ||||
|         self.assertTrue( | ||||
|             Event.objects.filter( | ||||
|                 action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username | ||||
|             ).exists() | ||||
|         ) | ||||
|  | ||||
|     def test_group_create_duplicate(self): | ||||
|         """Test group create (duplicate)""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.post( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)} | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 409) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, | ||||
|             { | ||||
|                 "detail": "Group with ID exists already.", | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], | ||||
|                 "scimType": "uniqueness", | ||||
|                 "status": 409, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_group_update(self): | ||||
|         """Test group update""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)} | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|  | ||||
|     def test_group_update_non_existent(self): | ||||
|         """Test group update""" | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "group_id": str(uuid4()), | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=404) | ||||
|         self.assertJSONEqual( | ||||
|             response.content, | ||||
|             { | ||||
|                 "detail": "Group not found.", | ||||
|                 "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"], | ||||
|                 "status": 404, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_group_patch_add(self): | ||||
|         """Test group patch""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.patch( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "Add", | ||||
|                             "path": "members", | ||||
|                             "value": {"value": str(user.uuid)}, | ||||
|                         } | ||||
|                     ] | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         self.assertTrue(group.users.filter(pk=user.pk).exists()) | ||||
|  | ||||
|     def test_group_patch_remove(self): | ||||
|         """Test group patch""" | ||||
|         user = create_test_user() | ||||
|  | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         group.users.add(user) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.patch( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "Operations": [ | ||||
|                         { | ||||
|                             "op": "remove", | ||||
|                             "path": "members", | ||||
|                             "value": {"value": str(user.uuid)}, | ||||
|                         } | ||||
|                     ] | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=200) | ||||
|         self.assertFalse(group.users.filter(pk=user.pk).exists()) | ||||
|  | ||||
|     def test_group_delete(self): | ||||
|         """Test group delete""" | ||||
|         group = Group.objects.create(name=generate_id()) | ||||
|         SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4()) | ||||
|         response = self.client.delete( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-groups", | ||||
|                 kwargs={"source_slug": self.source.slug, "group_id": group.pk}, | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, second=204) | ||||
| @ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase): | ||||
|             SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], | ||||
|             "0123456789", | ||||
|         ) | ||||
|  | ||||
|     def test_user_update(self): | ||||
|         """Test user update""" | ||||
|         user = create_test_user() | ||||
|         existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) | ||||
|         ext_id = generate_id() | ||||
|         response = self.client.put( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-users", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "user_id": str(user.uuid), | ||||
|                 }, | ||||
|             ), | ||||
|             data=dumps( | ||||
|                 { | ||||
|                     "id": str(existing.pk), | ||||
|                     "userName": generate_id(), | ||||
|                     "externalId": ext_id, | ||||
|                     "emails": [ | ||||
|                         { | ||||
|                             "primary": True, | ||||
|                             "value": user.email, | ||||
|                         } | ||||
|                     ], | ||||
|                 } | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|     def test_user_delete(self): | ||||
|         """Test user delete""" | ||||
|         user = create_test_user() | ||||
|         SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4()) | ||||
|         response = self.client.delete( | ||||
|             reverse( | ||||
|                 "authentik_sources_scim:v2-users", | ||||
|                 kwargs={ | ||||
|                     "source_slug": self.source.slug, | ||||
|                     "user_id": str(user.uuid), | ||||
|                 }, | ||||
|             ), | ||||
|             content_type=SCIM_CONTENT_TYPE, | ||||
|             HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}", | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 204) | ||||
|  | ||||
| @ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_ | ||||
| from rest_framework.request import Request | ||||
| from rest_framework.views import APIView | ||||
|  | ||||
| from authentik.core.middleware import CTX_AUTH_VIA | ||||
| from authentik.core.models import Token, TokenIntents, User | ||||
| from authentik.sources.scim.models import SCIMSource | ||||
|  | ||||
| @ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication): | ||||
|         _username, _, password = b64decode(key.encode()).decode().partition(":") | ||||
|         token = self.check_token(password, source_slug) | ||||
|         if token: | ||||
|             CTX_AUTH_VIA.set("scim_basic") | ||||
|             return (token.user, token) | ||||
|         return None | ||||
|  | ||||
| @ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication): | ||||
|         token = self.check_token(key, source_slug) | ||||
|         if not token: | ||||
|             return None | ||||
|         CTX_AUTH_VIA.set("scim_token") | ||||
|         return (token.user, token) | ||||
|  | ||||
| @ -1,11 +1,13 @@ | ||||
| """SCIM Utils""" | ||||
|  | ||||
| from typing import Any | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.core.paginator import Page, Paginator | ||||
| from django.db.models import Q, QuerySet | ||||
| from django.http import HttpRequest | ||||
| from django.urls import resolve | ||||
| from rest_framework.parsers import JSONParser | ||||
| from rest_framework.permissions import IsAuthenticated | ||||
| from rest_framework.renderers import JSONRenderer | ||||
| @ -44,7 +46,7 @@ class SCIMView(APIView): | ||||
|     logger: BoundLogger | ||||
|  | ||||
|     permission_classes = [IsAuthenticated] | ||||
|     parser_classes = [SCIMParser, JSONParser] | ||||
|     parser_classes = [SCIMParser] | ||||
|     renderer_classes = [SCIMRenderer] | ||||
|  | ||||
|     def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: | ||||
| @ -54,6 +56,28 @@ class SCIMView(APIView): | ||||
|     def get_authenticators(self): | ||||
|         return [SCIMTokenAuth(self)] | ||||
|  | ||||
|     def patch_resolve_value(self, raw_value: dict) -> User | Group | None: | ||||
|         """Attempt to resolve a raw `value` attribute of a patch operation into | ||||
|         a database model""" | ||||
|         model = User | ||||
|         query = {} | ||||
|         if "$ref" in raw_value: | ||||
|             url = urlparse(raw_value["$ref"]) | ||||
|             if match := resolve(url.path): | ||||
|                 if match.url_name == "v2-users": | ||||
|                     model = User | ||||
|                     query = {"pk": int(match.kwargs["user_id"])} | ||||
|         elif "type" in raw_value: | ||||
|             match raw_value["type"]: | ||||
|                 case "User": | ||||
|                     model = User | ||||
|                     query = {"pk": int(raw_value["value"])} | ||||
|                 case "Group": | ||||
|                     model = Group | ||||
|         else: | ||||
|             return None | ||||
|         return model.objects.filter(**query).first() | ||||
|  | ||||
|     def filter_parse(self, request: Request): | ||||
|         """Parse the path of a Patch Operation""" | ||||
|         path = request.query_params.get("filter") | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	