Compare commits
	
		
			2 Commits
		
	
	
		
			website/do
			...
			providers/
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7549a6b83d | |||
| bb45b714e2 | 
@ -1,5 +1,5 @@
 | 
				
			|||||||
[bumpversion]
 | 
					[bumpversion]
 | 
				
			||||||
current_version = 2025.6.2
 | 
					current_version = 2025.6.1
 | 
				
			||||||
tag = True
 | 
					tag = True
 | 
				
			||||||
commit = True
 | 
					commit = True
 | 
				
			||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
 | 
					parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
 | 
				
			||||||
@ -21,8 +21,6 @@ optional_value = final
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:package.json]
 | 
					[bumpversion:file:package.json]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:package-lock.json]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[bumpversion:file:docker-compose.yml]
 | 
					[bumpversion:file:docker-compose.yml]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:schema.yml]
 | 
					[bumpversion:file:schema.yml]
 | 
				
			||||||
@ -33,4 +31,6 @@ optional_value = final
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:internal/constants/constants.go]
 | 
					[bumpversion:file:internal/constants/constants.go]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[bumpversion:file:web/src/common/constants.ts]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[bumpversion:file:lifecycle/aws/template.yaml]
 | 
					[bumpversion:file:lifecycle/aws/template.yaml]
 | 
				
			||||||
 | 
				
			|||||||
@ -7,9 +7,6 @@ charset = utf-8
 | 
				
			|||||||
trim_trailing_whitespace = true
 | 
					trim_trailing_whitespace = true
 | 
				
			||||||
insert_final_newline = true
 | 
					insert_final_newline = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[*.toml]
 | 
					 | 
				
			||||||
indent_size = 2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[*.html]
 | 
					[*.html]
 | 
				
			||||||
indent_size = 2
 | 
					indent_size = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main-daily.yml
									
									
									
									
										vendored
									
									
								
							@ -15,8 +15,8 @@ jobs:
 | 
				
			|||||||
      matrix:
 | 
					      matrix:
 | 
				
			||||||
        version:
 | 
					        version:
 | 
				
			||||||
          - docs
 | 
					          - docs
 | 
				
			||||||
          - version-2025-4
 | 
					 | 
				
			||||||
          - version-2025-2
 | 
					          - version-2025-2
 | 
				
			||||||
 | 
					          - version-2024-12
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
      - uses: actions/checkout@v4
 | 
					      - uses: actions/checkout@v4
 | 
				
			||||||
      - run: |
 | 
					      - run: |
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci-main.yml
									
									
									
									
										vendored
									
									
								
							@ -202,7 +202,7 @@ jobs:
 | 
				
			|||||||
        uses: actions/cache@v4
 | 
					        uses: actions/cache@v4
 | 
				
			||||||
        with:
 | 
					        with:
 | 
				
			||||||
          path: web/dist
 | 
					          path: web/dist
 | 
				
			||||||
          key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
 | 
					          key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
 | 
				
			||||||
      - name: prepare web ui
 | 
					      - name: prepare web ui
 | 
				
			||||||
        if: steps.cache-web.outputs.cache-hit != 'true'
 | 
					        if: steps.cache-web.outputs.cache-hit != 'true'
 | 
				
			||||||
        working-directory: web
 | 
					        working-directory: web
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										22
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										22
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							@ -41,27 +41,6 @@ jobs:
 | 
				
			|||||||
      - name: test
 | 
					      - name: test
 | 
				
			||||||
        working-directory: website/
 | 
					        working-directory: website/
 | 
				
			||||||
        run: npm test
 | 
					        run: npm test
 | 
				
			||||||
  build:
 | 
					 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					 | 
				
			||||||
    name: ${{ matrix.job }}
 | 
					 | 
				
			||||||
    strategy:
 | 
					 | 
				
			||||||
      fail-fast: false
 | 
					 | 
				
			||||||
      matrix:
 | 
					 | 
				
			||||||
        job:
 | 
					 | 
				
			||||||
          - build
 | 
					 | 
				
			||||||
          - build:integrations
 | 
					 | 
				
			||||||
    steps:
 | 
					 | 
				
			||||||
      - uses: actions/checkout@v4
 | 
					 | 
				
			||||||
      - uses: actions/setup-node@v4
 | 
					 | 
				
			||||||
        with:
 | 
					 | 
				
			||||||
          node-version-file: website/package.json
 | 
					 | 
				
			||||||
          cache: "npm"
 | 
					 | 
				
			||||||
          cache-dependency-path: website/package-lock.json
 | 
					 | 
				
			||||||
      - working-directory: website/
 | 
					 | 
				
			||||||
        run: npm ci
 | 
					 | 
				
			||||||
      - name: build
 | 
					 | 
				
			||||||
        working-directory: website/
 | 
					 | 
				
			||||||
        run: npm run ${{ matrix.job }}
 | 
					 | 
				
			||||||
  build-container:
 | 
					  build-container:
 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    permissions:
 | 
					    permissions:
 | 
				
			||||||
@ -115,7 +94,6 @@ jobs:
 | 
				
			|||||||
    needs:
 | 
					    needs:
 | 
				
			||||||
      - lint
 | 
					      - lint
 | 
				
			||||||
      - test
 | 
					      - test
 | 
				
			||||||
      - build
 | 
					 | 
				
			||||||
      - build-container
 | 
					      - build-container
 | 
				
			||||||
    runs-on: ubuntu-latest
 | 
					    runs-on: ubuntu-latest
 | 
				
			||||||
    steps:
 | 
					    steps:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/codeql-analysis.yml
									
									
									
									
										vendored
									
									
								
							@ -2,7 +2,7 @@ name: "CodeQL"
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
on:
 | 
					on:
 | 
				
			||||||
  push:
 | 
					  push:
 | 
				
			||||||
    branches: [main, next, version*]
 | 
					    branches: [main, "*", next, version*]
 | 
				
			||||||
  pull_request:
 | 
					  pull_request:
 | 
				
			||||||
    branches: [main]
 | 
					    branches: [main]
 | 
				
			||||||
  schedule:
 | 
					  schedule:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
								
							@ -6,15 +6,13 @@
 | 
				
			|||||||
        "!Context scalar",
 | 
					        "!Context scalar",
 | 
				
			||||||
        "!Enumerate sequence",
 | 
					        "!Enumerate sequence",
 | 
				
			||||||
        "!Env scalar",
 | 
					        "!Env scalar",
 | 
				
			||||||
        "!Env sequence",
 | 
					 | 
				
			||||||
        "!Find sequence",
 | 
					        "!Find sequence",
 | 
				
			||||||
        "!Format sequence",
 | 
					        "!Format sequence",
 | 
				
			||||||
        "!If sequence",
 | 
					        "!If sequence",
 | 
				
			||||||
        "!Index scalar",
 | 
					        "!Index scalar",
 | 
				
			||||||
        "!KeyOf scalar",
 | 
					        "!KeyOf scalar",
 | 
				
			||||||
        "!Value scalar",
 | 
					        "!Value scalar",
 | 
				
			||||||
        "!AtIndex scalar",
 | 
					        "!AtIndex scalar"
 | 
				
			||||||
        "!ParseJSON scalar"
 | 
					 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    "typescript.preferences.importModuleSpecifier": "non-relative",
 | 
					    "typescript.preferences.importModuleSpecifier": "non-relative",
 | 
				
			||||||
    "typescript.preferences.importModuleSpecifierEnding": "index",
 | 
					    "typescript.preferences.importModuleSpecifierEnding": "index",
 | 
				
			||||||
 | 
				
			|||||||
@ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
 | 
				
			|||||||
    /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
 | 
					    /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Stage 4: Download uv
 | 
					# Stage 4: Download uv
 | 
				
			||||||
FROM ghcr.io/astral-sh/uv:0.7.14 AS uv
 | 
					FROM ghcr.io/astral-sh/uv:0.7.13 AS uv
 | 
				
			||||||
# Stage 5: Base python image
 | 
					# Stage 5: Base python image
 | 
				
			||||||
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
 | 
					FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ENV VENV_PATH="/ak-root/.venv" \
 | 
					ENV VENV_PATH="/ak-root/.venv" \
 | 
				
			||||||
    PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \
 | 
					    PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								Makefile
									
									
									
									
									
								
							@ -86,10 +86,6 @@ dev-create-db:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
 | 
					dev-reset: dev-drop-db dev-create-db migrate  ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
update-test-mmdb:  ## Update test GeoIP and ASN Databases
 | 
					 | 
				
			||||||
	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb
 | 
					 | 
				
			||||||
	curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#########################
 | 
					#########################
 | 
				
			||||||
## API Schema
 | 
					## API Schema
 | 
				
			||||||
#########################
 | 
					#########################
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from os import environ
 | 
					from os import environ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "2025.6.2"
 | 
					__version__ = "2025.6.1"
 | 
				
			||||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
					ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -37,7 +37,6 @@ entries:
 | 
				
			|||||||
    - attrs:
 | 
					    - attrs:
 | 
				
			||||||
          attributes:
 | 
					          attributes:
 | 
				
			||||||
              env_null: !Env [bar-baz, null]
 | 
					              env_null: !Env [bar-baz, null]
 | 
				
			||||||
              json_parse: !ParseJSON '{"foo": "bar"}'
 | 
					 | 
				
			||||||
              policy_pk1:
 | 
					              policy_pk1:
 | 
				
			||||||
                  !Format [
 | 
					                  !Format [
 | 
				
			||||||
                      "%s-%s",
 | 
					                      "%s-%s",
 | 
				
			||||||
 | 
				
			|||||||
@ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
 | 
					for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
 | 
				
			||||||
    if "local" in str(blueprint_file) or "testing" in str(blueprint_file):
 | 
					    if "local" in str(blueprint_file):
 | 
				
			||||||
        continue
 | 
					        continue
 | 
				
			||||||
    setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))
 | 
					    setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))
 | 
				
			||||||
 | 
				
			|||||||
@ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase):
 | 
				
			|||||||
                    },
 | 
					                    },
 | 
				
			||||||
                    "nested_context": "context-nested-value",
 | 
					                    "nested_context": "context-nested-value",
 | 
				
			||||||
                    "env_null": None,
 | 
					                    "env_null": None,
 | 
				
			||||||
                    "json_parse": {"foo": "bar"},
 | 
					 | 
				
			||||||
                    "at_index_sequence": "foo",
 | 
					                    "at_index_sequence": "foo",
 | 
				
			||||||
                    "at_index_sequence_default": "non existent",
 | 
					                    "at_index_sequence_default": "non existent",
 | 
				
			||||||
                    "at_index_mapping": 2,
 | 
					                    "at_index_mapping": 2,
 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,6 @@ from copy import copy
 | 
				
			|||||||
from dataclasses import asdict, dataclass, field, is_dataclass
 | 
					from dataclasses import asdict, dataclass, field, is_dataclass
 | 
				
			||||||
from enum import Enum
 | 
					from enum import Enum
 | 
				
			||||||
from functools import reduce
 | 
					from functools import reduce
 | 
				
			||||||
from json import JSONDecodeError, loads
 | 
					 | 
				
			||||||
from operator import ixor
 | 
					from operator import ixor
 | 
				
			||||||
from os import getenv
 | 
					from os import getenv
 | 
				
			||||||
from typing import Any, Literal, Union
 | 
					from typing import Any, Literal, Union
 | 
				
			||||||
@ -292,22 +291,6 @@ class Context(YAMLTag):
 | 
				
			|||||||
        return value
 | 
					        return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ParseJSON(YAMLTag):
 | 
					 | 
				
			||||||
    """Parse JSON from context/env/etc value"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    raw: str
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.raw = node.value
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            return loads(self.raw)
 | 
					 | 
				
			||||||
        except JSONDecodeError as exc:
 | 
					 | 
				
			||||||
            raise EntryInvalidError.from_entry(exc, entry) from exc
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Format(YAMLTag):
 | 
					class Format(YAMLTag):
 | 
				
			||||||
    """Format a string"""
 | 
					    """Format a string"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -683,7 +666,6 @@ class BlueprintLoader(SafeLoader):
 | 
				
			|||||||
        self.add_constructor("!Value", Value)
 | 
					        self.add_constructor("!Value", Value)
 | 
				
			||||||
        self.add_constructor("!Index", Index)
 | 
					        self.add_constructor("!Index", Index)
 | 
				
			||||||
        self.add_constructor("!AtIndex", AtIndex)
 | 
					        self.add_constructor("!AtIndex", AtIndex)
 | 
				
			||||||
        self.add_constructor("!ParseJSON", ParseJSON)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class EntryInvalidError(SentryIgnoredException):
 | 
					class EntryInvalidError(SentryIgnoredException):
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,8 @@
 | 
				
			|||||||
"""Authenticator Devices API Views"""
 | 
					"""Authenticator Devices API Views"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from drf_spectacular.utils import extend_schema
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
 | 
					from drf_spectacular.types import OpenApiTypes
 | 
				
			||||||
 | 
					from drf_spectacular.utils import OpenApiParameter, extend_schema
 | 
				
			||||||
from guardian.shortcuts import get_objects_for_user
 | 
					from guardian.shortcuts import get_objects_for_user
 | 
				
			||||||
from rest_framework.fields import (
 | 
					from rest_framework.fields import (
 | 
				
			||||||
    BooleanField,
 | 
					    BooleanField,
 | 
				
			||||||
@ -13,7 +15,6 @@ from rest_framework.request import Request
 | 
				
			|||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
from rest_framework.viewsets import ViewSet
 | 
					from rest_framework.viewsets import ViewSet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.api.users import ParamUserSerializer
 | 
					 | 
				
			||||||
from authentik.core.api.utils import MetaNameSerializer
 | 
					from authentik.core.api.utils import MetaNameSerializer
 | 
				
			||||||
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
 | 
					from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
 | 
				
			||||||
from authentik.stages.authenticator import device_classes, devices_for_user
 | 
					from authentik.stages.authenticator import device_classes, devices_for_user
 | 
				
			||||||
@ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DeviceSerializer(MetaNameSerializer):
 | 
					class DeviceSerializer(MetaNameSerializer):
 | 
				
			||||||
    """Serializer for authenticator devices"""
 | 
					    """Serializer for Duo authenticator devices"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    pk = CharField()
 | 
					    pk = CharField()
 | 
				
			||||||
    name = CharField()
 | 
					    name = CharField()
 | 
				
			||||||
@ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer):
 | 
				
			|||||||
    last_updated = DateTimeField(read_only=True)
 | 
					    last_updated = DateTimeField(read_only=True)
 | 
				
			||||||
    last_used = DateTimeField(read_only=True, allow_null=True)
 | 
					    last_used = DateTimeField(read_only=True, allow_null=True)
 | 
				
			||||||
    extra_description = SerializerMethodField()
 | 
					    extra_description = SerializerMethodField()
 | 
				
			||||||
    external_id = SerializerMethodField()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_type(self, instance: Device) -> str:
 | 
					    def get_type(self, instance: Device) -> str:
 | 
				
			||||||
        """Get type of device"""
 | 
					        """Get type of device"""
 | 
				
			||||||
        return instance._meta.label
 | 
					        return instance._meta.label
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_extra_description(self, instance: Device) -> str | None:
 | 
					    def get_extra_description(self, instance: Device) -> str:
 | 
				
			||||||
        """Get extra description"""
 | 
					        """Get extra description"""
 | 
				
			||||||
        if isinstance(instance, WebAuthnDevice):
 | 
					        if isinstance(instance, WebAuthnDevice):
 | 
				
			||||||
            return instance.device_type.description if instance.device_type else None
 | 
					            return (
 | 
				
			||||||
 | 
					                instance.device_type.description
 | 
				
			||||||
 | 
					                if instance.device_type
 | 
				
			||||||
 | 
					                else _("Extra description not available")
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        if isinstance(instance, EndpointDevice):
 | 
					        if isinstance(instance, EndpointDevice):
 | 
				
			||||||
            return instance.data.get("deviceSignals", {}).get("deviceModel")
 | 
					            return instance.data.get("deviceSignals", {}).get("deviceModel")
 | 
				
			||||||
        return None
 | 
					        return ""
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_external_id(self, instance: Device) -> str | None:
 | 
					 | 
				
			||||||
        """Get external Device ID"""
 | 
					 | 
				
			||||||
        if isinstance(instance, WebAuthnDevice):
 | 
					 | 
				
			||||||
            return instance.device_type.aaguid if instance.device_type else None
 | 
					 | 
				
			||||||
        if isinstance(instance, EndpointDevice):
 | 
					 | 
				
			||||||
            return instance.data.get("deviceSignals", {}).get("deviceModel")
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DeviceViewSet(ViewSet):
 | 
					class DeviceViewSet(ViewSet):
 | 
				
			||||||
@ -61,6 +57,7 @@ class DeviceViewSet(ViewSet):
 | 
				
			|||||||
    serializer_class = DeviceSerializer
 | 
					    serializer_class = DeviceSerializer
 | 
				
			||||||
    permission_classes = [IsAuthenticated]
 | 
					    permission_classes = [IsAuthenticated]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @extend_schema(responses={200: DeviceSerializer(many=True)})
 | 
				
			||||||
    def list(self, request: Request) -> Response:
 | 
					    def list(self, request: Request) -> Response:
 | 
				
			||||||
        """Get all devices for current user"""
 | 
					        """Get all devices for current user"""
 | 
				
			||||||
        devices = devices_for_user(request.user)
 | 
					        devices = devices_for_user(request.user)
 | 
				
			||||||
@ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet):
 | 
				
			|||||||
            yield from device_set
 | 
					            yield from device_set
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @extend_schema(
 | 
					    @extend_schema(
 | 
				
			||||||
        parameters=[ParamUserSerializer],
 | 
					        parameters=[
 | 
				
			||||||
 | 
					            OpenApiParameter(
 | 
				
			||||||
 | 
					                name="user",
 | 
				
			||||||
 | 
					                location=OpenApiParameter.QUERY,
 | 
				
			||||||
 | 
					                type=OpenApiTypes.INT,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
        responses={200: DeviceSerializer(many=True)},
 | 
					        responses={200: DeviceSerializer(many=True)},
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    def list(self, request: Request) -> Response:
 | 
					    def list(self, request: Request) -> Response:
 | 
				
			||||||
        """Get all devices for current user"""
 | 
					        """Get all devices for current user"""
 | 
				
			||||||
        args = ParamUserSerializer(data=request.query_params)
 | 
					        kwargs = {}
 | 
				
			||||||
        args.is_valid(raise_exception=True)
 | 
					        if "user" in request.query_params:
 | 
				
			||||||
        return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data)
 | 
					            kwargs = {"user": request.query_params["user"]}
 | 
				
			||||||
 | 
					        return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data)
 | 
				
			||||||
 | 
				
			|||||||
@ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
 | 
				
			|||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ParamUserSerializer(PassiveSerializer):
 | 
					 | 
				
			||||||
    """Partial serializer for query parameters to select a user"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class UserGroupSerializer(ModelSerializer):
 | 
					class UserGroupSerializer(ModelSerializer):
 | 
				
			||||||
    """Simplified Group Serializer for user's groups"""
 | 
					    """Simplified Group Serializer for user's groups"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
    queryset = User.objects.none()
 | 
					    queryset = User.objects.none()
 | 
				
			||||||
    ordering = ["username"]
 | 
					    ordering = ["username"]
 | 
				
			||||||
    serializer_class = UserSerializer
 | 
					    serializer_class = UserSerializer
 | 
				
			||||||
    filterset_class = UsersFilter
 | 
					 | 
				
			||||||
    search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
 | 
					    search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
 | 
				
			||||||
 | 
					    filterset_class = UsersFilter
 | 
				
			||||||
    def get_ql_fields(self):
 | 
					 | 
				
			||||||
        from djangoql.schema import BoolField, StrField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return [
 | 
					 | 
				
			||||||
            StrField(User, "username"),
 | 
					 | 
				
			||||||
            StrField(User, "name"),
 | 
					 | 
				
			||||||
            StrField(User, "email"),
 | 
					 | 
				
			||||||
            StrField(User, "path"),
 | 
					 | 
				
			||||||
            BoolField(User, "is_active", nullable=True),
 | 
					 | 
				
			||||||
            ChoiceSearchField(User, "type"),
 | 
					 | 
				
			||||||
            JSONSearchField(User, "attributes", suggest_nested=False),
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_queryset(self):
 | 
					    def get_queryset(self):
 | 
				
			||||||
        base_qs = User.objects.all().exclude_anonymous()
 | 
					        base_qs = User.objects.all().exclude_anonymous()
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,6 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.db import models
 | 
					 | 
				
			||||||
from django.db.models import Model
 | 
					from django.db.models import Model
 | 
				
			||||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
 | 
					from drf_spectacular.extensions import OpenApiSerializerFieldExtension
 | 
				
			||||||
from drf_spectacular.plumbing import build_basic_type
 | 
					from drf_spectacular.plumbing import build_basic_type
 | 
				
			||||||
@ -31,27 +30,7 @@ def is_dict(value: Any):
 | 
				
			|||||||
    raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
 | 
					    raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class JSONDictField(JSONField):
 | 
					 | 
				
			||||||
    """JSON Field which only allows dictionaries"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    default_validators = [is_dict]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class JSONExtension(OpenApiSerializerFieldExtension):
 | 
					 | 
				
			||||||
    """Generate API Schema for JSON fields as"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    target_class = "authentik.core.api.utils.JSONDictField"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def map_serializer_field(self, auto_schema, direction):
 | 
					 | 
				
			||||||
        return build_basic_type(OpenApiTypes.OBJECT)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ModelSerializer(BaseModelSerializer):
 | 
					class ModelSerializer(BaseModelSerializer):
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # By default, JSON fields we have are used to store dictionaries
 | 
					 | 
				
			||||||
    serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
 | 
					 | 
				
			||||||
    serializer_field_mapping[models.JSONField] = JSONDictField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def create(self, validated_data):
 | 
					    def create(self, validated_data):
 | 
				
			||||||
        instance = super().create(validated_data)
 | 
					        instance = super().create(validated_data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer):
 | 
				
			|||||||
        return instance
 | 
					        return instance
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class JSONDictField(JSONField):
 | 
				
			||||||
 | 
					    """JSON Field which only allows dictionaries"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    default_validators = [is_dict]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class JSONExtension(OpenApiSerializerFieldExtension):
 | 
				
			||||||
 | 
					    """Generate API Schema for JSON fields as"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    target_class = "authentik.core.api.utils.JSONDictField"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def map_serializer_field(self, auto_schema, direction):
 | 
				
			||||||
 | 
					        return build_basic_type(OpenApiTypes.OBJECT)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PassiveSerializer(Serializer):
 | 
					class PassiveSerializer(Serializer):
 | 
				
			||||||
    """Base serializer class which doesn't implement create/update methods"""
 | 
					    """Base serializer class which doesn't implement create/update methods"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -13,6 +13,7 @@ class Command(TenantCommand):
 | 
				
			|||||||
        parser.add_argument("usernames", nargs="*", type=str)
 | 
					        parser.add_argument("usernames", nargs="*", type=str)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def handle_per_tenant(self, **options):
 | 
					    def handle_per_tenant(self, **options):
 | 
				
			||||||
 | 
					        print(options)
 | 
				
			||||||
        new_type = UserTypes(options["type"])
 | 
					        new_type = UserTypes(options["type"])
 | 
				
			||||||
        qs = (
 | 
					        qs = (
 | 
				
			||||||
            User.objects.exclude_anonymous()
 | 
					            User.objects.exclude_anonymous()
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ from django.http import HttpRequest
 | 
				
			|||||||
from django.utils.functional import SimpleLazyObject, cached_property
 | 
					from django.utils.functional import SimpleLazyObject, cached_property
 | 
				
			||||||
from django.utils.timezone import now
 | 
					from django.utils.timezone import now
 | 
				
			||||||
from django.utils.translation import gettext_lazy as _
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from django_cte import CTE, with_cte
 | 
					from django_cte import CTEQuerySet, With
 | 
				
			||||||
from guardian.conf import settings
 | 
					from guardian.conf import settings
 | 
				
			||||||
from guardian.mixins import GuardianUserMixin
 | 
					from guardian.mixins import GuardianUserMixin
 | 
				
			||||||
from model_utils.managers import InheritanceManager
 | 
					from model_utils.managers import InheritanceManager
 | 
				
			||||||
@ -136,7 +136,7 @@ class AttributesMixin(models.Model):
 | 
				
			|||||||
        return instance, False
 | 
					        return instance, False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GroupQuerySet(QuerySet):
 | 
					class GroupQuerySet(CTEQuerySet):
 | 
				
			||||||
    def with_children_recursive(self):
 | 
					    def with_children_recursive(self):
 | 
				
			||||||
        """Recursively get all groups that have the current queryset as parents
 | 
					        """Recursively get all groups that have the current queryset as parents
 | 
				
			||||||
        or are indirectly related."""
 | 
					        or are indirectly related."""
 | 
				
			||||||
@ -165,9 +165,9 @@ class GroupQuerySet(QuerySet):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Build the recursive query, see above
 | 
					        # Build the recursive query, see above
 | 
				
			||||||
        cte = CTE.recursive(make_cte)
 | 
					        cte = With.recursive(make_cte)
 | 
				
			||||||
        # Return the result, as a usable queryset for Group.
 | 
					        # Return the result, as a usable queryset for Group.
 | 
				
			||||||
        return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid))
 | 
					        return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Group(SerializerModel, AttributesMixin):
 | 
					class Group(SerializerModel, AttributesMixin):
 | 
				
			||||||
 | 
				
			|||||||
@ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            response.content.decode(),
 | 
					            response.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
@ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            response.content.decode(),
 | 
					            response.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
 | 
				
			|||||||
@ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase):
 | 
				
			|||||||
            [
 | 
					            [
 | 
				
			||||||
                UserPasswordHistory(
 | 
					                UserPasswordHistory(
 | 
				
			||||||
                    user=self.user,
 | 
					                    user=self.user,
 | 
				
			||||||
                    old_password="hunter1",  # nosec
 | 
					                    old_password="hunter1",  # nosec B106
 | 
				
			||||||
                    created_at=_now - timedelta(days=3),
 | 
					                    created_at=_now - timedelta(days=3),
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                UserPasswordHistory(
 | 
					                UserPasswordHistory(
 | 
				
			||||||
                    user=self.user,
 | 
					                    user=self.user,
 | 
				
			||||||
                    old_password="hunter2",  # nosec
 | 
					                    old_password="hunter2",  # nosec B106
 | 
				
			||||||
                    created_at=_now - timedelta(days=2),
 | 
					                    created_at=_now - timedelta(days=2),
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                UserPasswordHistory(
 | 
					                UserPasswordHistory(
 | 
				
			||||||
                    user=self.user,
 | 
					                    user=self.user,
 | 
				
			||||||
                    old_password="hunter3",  # nosec
 | 
					                    old_password="hunter3",  # nosec B106
 | 
				
			||||||
                    created_at=_now,
 | 
					                    created_at=_now,
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
            ]
 | 
					            ]
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,10 @@
 | 
				
			|||||||
from hashlib import sha256
 | 
					from hashlib import sha256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.contrib.auth.signals import user_logged_out
 | 
				
			||||||
from django.db.models import Model
 | 
					from django.db.models import Model
 | 
				
			||||||
from django.db.models.signals import post_delete, post_save, pre_delete
 | 
					from django.db.models.signals import post_delete, post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
 | 
					from django.http.request import HttpRequest
 | 
				
			||||||
from guardian.shortcuts import assign_perm
 | 
					from guardian.shortcuts import assign_perm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import (
 | 
					from authentik.core.models import (
 | 
				
			||||||
@ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
 | 
				
			|||||||
            instance.save()
 | 
					            instance.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@receiver(user_logged_out)
 | 
				
			||||||
 | 
					def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_):
 | 
				
			||||||
 | 
					    """Session revoked trigger (user logged out)"""
 | 
				
			||||||
 | 
					    if not request.session or not request.session.session_key or not user:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    send_ssf_event(
 | 
				
			||||||
 | 
					        EventTypes.CAEP_SESSION_REVOKED,
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            "initiating_entity": "user",
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        sub_id={
 | 
				
			||||||
 | 
					            "format": "complex",
 | 
				
			||||||
 | 
					            "session": {
 | 
				
			||||||
 | 
					                "format": "opaque",
 | 
				
			||||||
 | 
					                "id": sha256(request.session.session_key.encode("ascii")).hexdigest(),
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            "user": {
 | 
				
			||||||
 | 
					                "format": "email",
 | 
				
			||||||
 | 
					                "email": user.email,
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        request=request,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
					@receiver(pre_delete, sender=AuthenticatedSession)
 | 
				
			||||||
def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
 | 
					def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
 | 
				
			||||||
    """Session revoked trigger (users' session has been deleted)
 | 
					    """Session revoked trigger (users' session has been deleted)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,12 +0,0 @@
 | 
				
			|||||||
"""Enterprise app config"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.enterprise.apps import EnterpriseConfig
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AuthentikEnterpriseSearchConfig(EnterpriseConfig):
 | 
					 | 
				
			||||||
    """Enterprise app config"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    name = "authentik.enterprise.search"
 | 
					 | 
				
			||||||
    label = "authentik_search"
 | 
					 | 
				
			||||||
    verbose_name = "authentik Enterprise.Search"
 | 
					 | 
				
			||||||
    default = True
 | 
					 | 
				
			||||||
@ -1,128 +0,0 @@
 | 
				
			|||||||
"""DjangoQL search"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from collections import OrderedDict, defaultdict
 | 
					 | 
				
			||||||
from collections.abc import Generator
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.db import connection
 | 
					 | 
				
			||||||
from django.db.models import Model, Q
 | 
					 | 
				
			||||||
from djangoql.compat import text_type
 | 
					 | 
				
			||||||
from djangoql.schema import StrField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class JSONSearchField(StrField):
 | 
					 | 
				
			||||||
    """JSON field for DjangoQL"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    model: Model
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
 | 
					 | 
				
			||||||
        # Set this in the constructor to not clobber the type variable
 | 
					 | 
				
			||||||
        self.type = "relation"
 | 
					 | 
				
			||||||
        self.suggest_nested = suggest_nested
 | 
					 | 
				
			||||||
        super().__init__(model, name, nullable)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_lookup(self, path, operator, value):
 | 
					 | 
				
			||||||
        search = "__".join(path)
 | 
					 | 
				
			||||||
        op, invert = self.get_operator(operator)
 | 
					 | 
				
			||||||
        q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
 | 
					 | 
				
			||||||
        return ~q if invert else q
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def json_field_keys(self) -> Generator[tuple[str]]:
 | 
					 | 
				
			||||||
        with connection.cursor() as cursor:
 | 
					 | 
				
			||||||
            cursor.execute(
 | 
					 | 
				
			||||||
                f"""
 | 
					 | 
				
			||||||
                WITH RECURSIVE "{self.name}_keys" AS (
 | 
					 | 
				
			||||||
                    SELECT
 | 
					 | 
				
			||||||
                        ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
 | 
					 | 
				
			||||||
                        "{self.name}" -> jsonb_object_keys("{self.name}") AS value
 | 
					 | 
				
			||||||
                    FROM {self.model._meta.db_table}
 | 
					 | 
				
			||||||
                    WHERE "{self.name}" IS NOT NULL
 | 
					 | 
				
			||||||
                        AND jsonb_typeof("{self.name}") = 'object'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    UNION ALL
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    SELECT
 | 
					 | 
				
			||||||
                        ck.key_path_array || jsonb_object_keys(ck.value),
 | 
					 | 
				
			||||||
                        ck.value -> jsonb_object_keys(ck.value) AS value
 | 
					 | 
				
			||||||
                    FROM "{self.name}_keys" ck
 | 
					 | 
				
			||||||
                    WHERE jsonb_typeof(ck.value) = 'object'
 | 
					 | 
				
			||||||
                ),
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                unique_paths AS (
 | 
					 | 
				
			||||||
                    SELECT DISTINCT key_path_array
 | 
					 | 
				
			||||||
                    FROM "{self.name}_keys"
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                SELECT key_path_array FROM unique_paths;
 | 
					 | 
				
			||||||
            """  # nosec
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            return (x[0] for x in cursor.fetchall())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_nested_options(self) -> OrderedDict:
 | 
					 | 
				
			||||||
        """Get keys of all nested objects to show autocomplete"""
 | 
					 | 
				
			||||||
        if not self.suggest_nested:
 | 
					 | 
				
			||||||
            return OrderedDict()
 | 
					 | 
				
			||||||
        base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
 | 
					 | 
				
			||||||
            if not parent_parts:
 | 
					 | 
				
			||||||
                parent_parts = []
 | 
					 | 
				
			||||||
            path = parts.pop(0)
 | 
					 | 
				
			||||||
            parent_parts.append(path)
 | 
					 | 
				
			||||||
            relation_key = "_".join(parent_parts)
 | 
					 | 
				
			||||||
            if len(parts) > 1:
 | 
					 | 
				
			||||||
                out_dict = {
 | 
					 | 
				
			||||||
                    relation_key: {
 | 
					 | 
				
			||||||
                        parts[0]: {
 | 
					 | 
				
			||||||
                            "type": "relation",
 | 
					 | 
				
			||||||
                            "relation": f"{relation_key}_{parts[0]}",
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                child_paths = recursive_function(parts.copy(), parent_parts.copy())
 | 
					 | 
				
			||||||
                child_paths.update(out_dict)
 | 
					 | 
				
			||||||
                return child_paths
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                return {relation_key: {parts[0]: {}}}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        relation_structure = defaultdict(dict)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for relations in self.json_field_keys():
 | 
					 | 
				
			||||||
            result = recursive_function([base_model_name] + relations)
 | 
					 | 
				
			||||||
            for relation_key, value in result.items():
 | 
					 | 
				
			||||||
                for sub_relation_key, sub_value in value.items():
 | 
					 | 
				
			||||||
                    if not relation_structure[relation_key].get(sub_relation_key, None):
 | 
					 | 
				
			||||||
                        relation_structure[relation_key][sub_relation_key] = sub_value
 | 
					 | 
				
			||||||
                    else:
 | 
					 | 
				
			||||||
                        relation_structure[relation_key][sub_relation_key].update(sub_value)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        final_dict = defaultdict(dict)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for key, value in relation_structure.items():
 | 
					 | 
				
			||||||
            for sub_key, sub_value in value.items():
 | 
					 | 
				
			||||||
                if not sub_value:
 | 
					 | 
				
			||||||
                    final_dict[key][sub_key] = {
 | 
					 | 
				
			||||||
                        "type": "str",
 | 
					 | 
				
			||||||
                        "nullable": True,
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    final_dict[key][sub_key] = sub_value
 | 
					 | 
				
			||||||
        return OrderedDict(final_dict)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def relation(self) -> str:
 | 
					 | 
				
			||||||
        return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ChoiceSearchField(StrField):
 | 
					 | 
				
			||||||
    def __init__(self, model=None, name=None, nullable=None):
 | 
					 | 
				
			||||||
        super().__init__(model, name, nullable, suggest_options=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_options(self, search):
 | 
					 | 
				
			||||||
        result = []
 | 
					 | 
				
			||||||
        choices = self._field_choices()
 | 
					 | 
				
			||||||
        if choices:
 | 
					 | 
				
			||||||
            search = search.lower()
 | 
					 | 
				
			||||||
            for c in choices:
 | 
					 | 
				
			||||||
                choice = text_type(c[0])
 | 
					 | 
				
			||||||
                if search in choice.lower():
 | 
					 | 
				
			||||||
                    result.append(choice)
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
@ -1,53 +0,0 @@
 | 
				
			|||||||
from rest_framework.response import Response
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.api.pagination import Pagination
 | 
					 | 
				
			||||||
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AutocompletePagination(Pagination):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def paginate_queryset(self, queryset, request, view=None):
 | 
					 | 
				
			||||||
        self.view = view
 | 
					 | 
				
			||||||
        return super().paginate_queryset(queryset, request, view)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_autocomplete(self):
 | 
					 | 
				
			||||||
        schema = QLSearch().get_schema(self.request, self.view)
 | 
					 | 
				
			||||||
        introspections = {}
 | 
					 | 
				
			||||||
        if hasattr(self.view, "get_ql_fields"):
 | 
					 | 
				
			||||||
            from authentik.enterprise.search.schema import AKQLSchemaSerializer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            introspections = AKQLSchemaSerializer().serialize(
 | 
					 | 
				
			||||||
                schema(self.page.paginator.object_list.model)
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        return introspections
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_paginated_response(self, data):
 | 
					 | 
				
			||||||
        previous_page_number = 0
 | 
					 | 
				
			||||||
        if self.page.has_previous():
 | 
					 | 
				
			||||||
            previous_page_number = self.page.previous_page_number()
 | 
					 | 
				
			||||||
        next_page_number = 0
 | 
					 | 
				
			||||||
        if self.page.has_next():
 | 
					 | 
				
			||||||
            next_page_number = self.page.next_page_number()
 | 
					 | 
				
			||||||
        return Response(
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					 | 
				
			||||||
                    "next": next_page_number,
 | 
					 | 
				
			||||||
                    "previous": previous_page_number,
 | 
					 | 
				
			||||||
                    "count": self.page.paginator.count,
 | 
					 | 
				
			||||||
                    "current": self.page.number,
 | 
					 | 
				
			||||||
                    "total_pages": self.page.paginator.num_pages,
 | 
					 | 
				
			||||||
                    "start_index": self.page.start_index(),
 | 
					 | 
				
			||||||
                    "end_index": self.page.end_index(),
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "results": data,
 | 
					 | 
				
			||||||
                "autocomplete": self.get_autocomplete(),
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_paginated_response_schema(self, schema):
 | 
					 | 
				
			||||||
        final_schema = super().get_paginated_response_schema(schema)
 | 
					 | 
				
			||||||
        final_schema["properties"]["autocomplete"] = {
 | 
					 | 
				
			||||||
            "$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}"
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        final_schema["required"].append("autocomplete")
 | 
					 | 
				
			||||||
        return final_schema
 | 
					 | 
				
			||||||
@ -1,78 +0,0 @@
 | 
				
			|||||||
"""DjangoQL search"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.apps import apps
 | 
					 | 
				
			||||||
from django.db.models import QuerySet
 | 
					 | 
				
			||||||
from djangoql.ast import Name
 | 
					 | 
				
			||||||
from djangoql.exceptions import DjangoQLError
 | 
					 | 
				
			||||||
from djangoql.queryset import apply_search
 | 
					 | 
				
			||||||
from djangoql.schema import DjangoQLSchema
 | 
					 | 
				
			||||||
from rest_framework.filters import SearchFilter
 | 
					 | 
				
			||||||
from rest_framework.request import Request
 | 
					 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.enterprise.search.fields import JSONSearchField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
LOGGER = get_logger()
 | 
					 | 
				
			||||||
AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete"
 | 
					 | 
				
			||||||
AUTOCOMPLETE_SCHEMA = {
 | 
					 | 
				
			||||||
    "type": "object",
 | 
					 | 
				
			||||||
    "additionalProperties": {},
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class BaseSchema(DjangoQLSchema):
 | 
					 | 
				
			||||||
    """Base Schema which deals with JSON Fields"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def resolve_name(self, name: Name):
 | 
					 | 
				
			||||||
        model = self.model_label(self.current_model)
 | 
					 | 
				
			||||||
        root_field = name.parts[0]
 | 
					 | 
				
			||||||
        field = self.models[model].get(root_field)
 | 
					 | 
				
			||||||
        # If the query goes into a JSON field, return the root
 | 
					 | 
				
			||||||
        # field as the JSON field will do the rest
 | 
					 | 
				
			||||||
        if isinstance(field, JSONSearchField):
 | 
					 | 
				
			||||||
            # This is a workaround; build_filter will remove the right-most
 | 
					 | 
				
			||||||
            # entry in the path as that is intended to be the same as the field
 | 
					 | 
				
			||||||
            # however for JSON that is not the case
 | 
					 | 
				
			||||||
            if name.parts[-1] != root_field:
 | 
					 | 
				
			||||||
                name.parts.append(root_field)
 | 
					 | 
				
			||||||
            return field
 | 
					 | 
				
			||||||
        return super().resolve_name(name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class QLSearch(SearchFilter):
 | 
					 | 
				
			||||||
    """rest_framework search filter which uses DjangoQL"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def enabled(self):
 | 
					 | 
				
			||||||
        return apps.get_app_config("authentik_enterprise").enabled()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_search_terms(self, request) -> str:
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        Search terms are set by a ?search=... query parameter,
 | 
					 | 
				
			||||||
        and may be comma and/or whitespace delimited.
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        params = request.query_params.get(self.search_param, "")
 | 
					 | 
				
			||||||
        params = params.replace("\x00", "")  # strip null characters
 | 
					 | 
				
			||||||
        return params
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_schema(self, request: Request, view) -> BaseSchema:
 | 
					 | 
				
			||||||
        ql_fields = []
 | 
					 | 
				
			||||||
        if hasattr(view, "get_ql_fields"):
 | 
					 | 
				
			||||||
            ql_fields = view.get_ql_fields()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        class InlineSchema(BaseSchema):
 | 
					 | 
				
			||||||
            def get_fields(self, model):
 | 
					 | 
				
			||||||
                return ql_fields or []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return InlineSchema
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
 | 
					 | 
				
			||||||
        search_query = self.get_search_terms(request)
 | 
					 | 
				
			||||||
        schema = self.get_schema(request, view)
 | 
					 | 
				
			||||||
        if len(search_query) == 0 or not self.enabled:
 | 
					 | 
				
			||||||
            return super().filter_queryset(request, queryset, view)
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            return apply_search(queryset, search_query, schema=schema)
 | 
					 | 
				
			||||||
        except DjangoQLError as exc:
 | 
					 | 
				
			||||||
            LOGGER.debug("Failed to parse search expression", exc=exc)
 | 
					 | 
				
			||||||
            return super().filter_queryset(request, queryset, view)
 | 
					 | 
				
			||||||
@ -1,29 +0,0 @@
 | 
				
			|||||||
from djangoql.serializers import DjangoQLSchemaSerializer
 | 
					 | 
				
			||||||
from drf_spectacular.generators import SchemaGenerator
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.api.schema import create_component
 | 
					 | 
				
			||||||
from authentik.enterprise.search.fields import JSONSearchField
 | 
					 | 
				
			||||||
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
 | 
					 | 
				
			||||||
    def serialize(self, schema):
 | 
					 | 
				
			||||||
        serialization = super().serialize(schema)
 | 
					 | 
				
			||||||
        for _, fields in schema.models.items():
 | 
					 | 
				
			||||||
            for _, field in fields.items():
 | 
					 | 
				
			||||||
                if not isinstance(field, JSONSearchField):
 | 
					 | 
				
			||||||
                    continue
 | 
					 | 
				
			||||||
                serialization["models"].update(field.get_nested_options())
 | 
					 | 
				
			||||||
        return serialization
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def serialize_field(self, field):
 | 
					 | 
				
			||||||
        result = super().serialize_field(field)
 | 
					 | 
				
			||||||
        if isinstance(field, JSONSearchField):
 | 
					 | 
				
			||||||
            result["relation"] = field.relation()
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
 | 
					 | 
				
			||||||
    create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return result
 | 
					 | 
				
			||||||
@ -1,17 +0,0 @@
 | 
				
			|||||||
SPECTACULAR_SETTINGS = {
 | 
					 | 
				
			||||||
    "POSTPROCESSING_HOOKS": [
 | 
					 | 
				
			||||||
        "authentik.api.schema.postprocess_schema_responses",
 | 
					 | 
				
			||||||
        "authentik.enterprise.search.schema.postprocess_schema_search_autocomplete",
 | 
					 | 
				
			||||||
        "drf_spectacular.hooks.postprocess_schema_enums",
 | 
					 | 
				
			||||||
    ],
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
REST_FRAMEWORK = {
 | 
					 | 
				
			||||||
    "DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination",
 | 
					 | 
				
			||||||
    "DEFAULT_FILTER_BACKENDS": [
 | 
					 | 
				
			||||||
        "authentik.enterprise.search.ql.QLSearch",
 | 
					 | 
				
			||||||
        "authentik.rbac.filters.ObjectFilter",
 | 
					 | 
				
			||||||
        "django_filters.rest_framework.DjangoFilterBackend",
 | 
					 | 
				
			||||||
        "rest_framework.filters.OrderingFilter",
 | 
					 | 
				
			||||||
    ],
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@ -1,78 +0,0 @@
 | 
				
			|||||||
from json import loads
 | 
					 | 
				
			||||||
from unittest.mock import PropertyMock, patch
 | 
					 | 
				
			||||||
from urllib.parse import urlencode
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.urls import reverse
 | 
					 | 
				
			||||||
from rest_framework.test import APITestCase
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.core.tests.utils import create_test_admin_user
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@patch(
 | 
					 | 
				
			||||||
    "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
					 | 
				
			||||||
    PropertyMock(return_value=True),
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
class QLTest(APITestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        self.user = create_test_admin_user()
 | 
					 | 
				
			||||||
        # ensure we have more than 1 user
 | 
					 | 
				
			||||||
        create_test_admin_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_search(self):
 | 
					 | 
				
			||||||
        """Test simple search query"""
 | 
					 | 
				
			||||||
        self.client.force_login(self.user)
 | 
					 | 
				
			||||||
        query = f'username = "{self.user.username}"'
 | 
					 | 
				
			||||||
        res = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_api:user-list",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            + f"?{urlencode({"search": query})}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(res.status_code, 200)
 | 
					 | 
				
			||||||
        content = loads(res.content)
 | 
					 | 
				
			||||||
        self.assertEqual(content["pagination"]["count"], 1)
 | 
					 | 
				
			||||||
        self.assertEqual(content["results"][0]["username"], self.user.username)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_no_search(self):
 | 
					 | 
				
			||||||
        """Ensure works with no search query"""
 | 
					 | 
				
			||||||
        self.client.force_login(self.user)
 | 
					 | 
				
			||||||
        res = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_api:user-list",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(res.status_code, 200)
 | 
					 | 
				
			||||||
        content = loads(res.content)
 | 
					 | 
				
			||||||
        self.assertNotEqual(content["pagination"]["count"], 1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_search_no_ql(self):
 | 
					 | 
				
			||||||
        """Test simple search query (no QL)"""
 | 
					 | 
				
			||||||
        self.client.force_login(self.user)
 | 
					 | 
				
			||||||
        res = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_api:user-list",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            + f"?{urlencode({"search": self.user.username})}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(res.status_code, 200)
 | 
					 | 
				
			||||||
        content = loads(res.content)
 | 
					 | 
				
			||||||
        self.assertGreaterEqual(content["pagination"]["count"], 1)
 | 
					 | 
				
			||||||
        self.assertEqual(content["results"][0]["username"], self.user.username)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_search_json(self):
 | 
					 | 
				
			||||||
        """Test search query with a JSON attribute"""
 | 
					 | 
				
			||||||
        self.user.attributes = {"foo": {"bar": "baz"}}
 | 
					 | 
				
			||||||
        self.user.save()
 | 
					 | 
				
			||||||
        self.client.force_login(self.user)
 | 
					 | 
				
			||||||
        query = 'attributes.foo.bar = "baz"'
 | 
					 | 
				
			||||||
        res = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_api:user-list",
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            + f"?{urlencode({"search": query})}"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(res.status_code, 200)
 | 
					 | 
				
			||||||
        content = loads(res.content)
 | 
					 | 
				
			||||||
        self.assertEqual(content["pagination"]["count"], 1)
 | 
					 | 
				
			||||||
        self.assertEqual(content["results"][0]["username"], self.user.username)
 | 
					 | 
				
			||||||
@ -18,7 +18,6 @@ TENANT_APPS = [
 | 
				
			|||||||
    "authentik.enterprise.providers.google_workspace",
 | 
					    "authentik.enterprise.providers.google_workspace",
 | 
				
			||||||
    "authentik.enterprise.providers.microsoft_entra",
 | 
					    "authentik.enterprise.providers.microsoft_entra",
 | 
				
			||||||
    "authentik.enterprise.providers.ssf",
 | 
					    "authentik.enterprise.providers.ssf",
 | 
				
			||||||
    "authentik.enterprise.search",
 | 
					 | 
				
			||||||
    "authentik.enterprise.stages.authenticator_endpoint_gdtc",
 | 
					    "authentik.enterprise.stages.authenticator_endpoint_gdtc",
 | 
				
			||||||
    "authentik.enterprise.stages.mtls",
 | 
					    "authentik.enterprise.stages.mtls",
 | 
				
			||||||
    "authentik.enterprise.stages.source",
 | 
					    "authentik.enterprise.stages.source",
 | 
				
			||||||
 | 
				
			|||||||
@ -97,7 +97,6 @@ class SourceStageFinal(StageView):
 | 
				
			|||||||
        token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
 | 
					        token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
 | 
				
			||||||
        self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
 | 
					        self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
 | 
				
			||||||
        plan = token.plan
 | 
					        plan = token.plan
 | 
				
			||||||
        plan.context.update(self.executor.plan.context)
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_IS_RESTORED] = token
 | 
					        plan.context[PLAN_CONTEXT_IS_RESTORED] = token
 | 
				
			||||||
        response = plan.to_redirect(self.request, token.flow)
 | 
					        response = plan.to_redirect(self.request, token.flow)
 | 
				
			||||||
        token.delete()
 | 
					        token.delete()
 | 
				
			||||||
 | 
				
			|||||||
@ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase):
 | 
				
			|||||||
        plan: FlowPlan = session[SESSION_KEY_PLAN]
 | 
					        plan: FlowPlan = session[SESSION_KEY_PLAN]
 | 
				
			||||||
        plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
 | 
					        plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
 | 
					        plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
 | 
				
			||||||
        plan.context["foo"] = "bar"
 | 
					 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Pretend we've just returned from the source
 | 
					        # Pretend we've just returned from the source
 | 
				
			||||||
        with self.assertFlowFinishes() as ff:
 | 
					        response = self.client.get(
 | 
				
			||||||
            response = self.client.get(
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
 | 
				
			||||||
                reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
 | 
					        )
 | 
				
			||||||
            )
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
            self.assertEqual(response.status_code, 200)
 | 
					        self.assertStageRedirects(
 | 
				
			||||||
            self.assertStageRedirects(
 | 
					            response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
 | 
				
			||||||
                response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
 | 
					        )
 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        self.assertEqual(ff().context["foo"], "bar")
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -132,22 +132,6 @@ class EventViewSet(ModelViewSet):
 | 
				
			|||||||
    ]
 | 
					    ]
 | 
				
			||||||
    filterset_class = EventsFilter
 | 
					    filterset_class = EventsFilter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_ql_fields(self):
 | 
					 | 
				
			||||||
        from djangoql.schema import DateTimeField, StrField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return [
 | 
					 | 
				
			||||||
            ChoiceSearchField(Event, "action"),
 | 
					 | 
				
			||||||
            StrField(Event, "event_uuid"),
 | 
					 | 
				
			||||||
            StrField(Event, "app", suggest_options=True),
 | 
					 | 
				
			||||||
            StrField(Event, "client_ip"),
 | 
					 | 
				
			||||||
            JSONSearchField(Event, "user", suggest_nested=False),
 | 
					 | 
				
			||||||
            JSONSearchField(Event, "brand", suggest_nested=False),
 | 
					 | 
				
			||||||
            JSONSearchField(Event, "context", suggest_nested=False),
 | 
					 | 
				
			||||||
            DateTimeField(Event, "created", suggest_options=True),
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @extend_schema(
 | 
					    @extend_schema(
 | 
				
			||||||
        methods=["GET"],
 | 
					        methods=["GET"],
 | 
				
			||||||
        responses={200: EventTopPerUserSerializer(many=True)},
 | 
					        responses={200: EventTopPerUserSerializer(many=True)},
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ from authentik.events.models import NotificationRule
 | 
				
			|||||||
class NotificationRuleSerializer(ModelSerializer):
 | 
					class NotificationRuleSerializer(ModelSerializer):
 | 
				
			||||||
    """NotificationRule Serializer"""
 | 
					    """NotificationRule Serializer"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    destination_group_obj = GroupSerializer(read_only=True, source="destination_group")
 | 
					    group_obj = GroupSerializer(read_only=True, source="group")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        model = NotificationRule
 | 
					        model = NotificationRule
 | 
				
			||||||
@ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer):
 | 
				
			|||||||
            "name",
 | 
					            "name",
 | 
				
			||||||
            "transports",
 | 
					            "transports",
 | 
				
			||||||
            "severity",
 | 
					            "severity",
 | 
				
			||||||
            "destination_group",
 | 
					            "group",
 | 
				
			||||||
            "destination_group_obj",
 | 
					            "group_obj",
 | 
				
			||||||
            "destination_event_user",
 | 
					 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    queryset = NotificationRule.objects.all()
 | 
					    queryset = NotificationRule.objects.all()
 | 
				
			||||||
    serializer_class = NotificationRuleSerializer
 | 
					    serializer_class = NotificationRuleSerializer
 | 
				
			||||||
    filterset_fields = ["name", "severity", "destination_group__name"]
 | 
					    filterset_fields = ["name", "severity", "group__name"]
 | 
				
			||||||
    ordering = ["name"]
 | 
					    ordering = ["name"]
 | 
				
			||||||
    search_fields = ["name", "destination_group__name"]
 | 
					    search_fields = ["name", "group__name"]
 | 
				
			||||||
 | 
				
			|||||||
@ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor):
 | 
				
			|||||||
        self.reader: Reader | None = None
 | 
					        self.reader: Reader | None = None
 | 
				
			||||||
        self._last_mtime: float = 0.0
 | 
					        self._last_mtime: float = 0.0
 | 
				
			||||||
        self.logger = get_logger()
 | 
					        self.logger = get_logger()
 | 
				
			||||||
        self.load()
 | 
					        self.open()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def path(self) -> str | None:
 | 
					    def path(self) -> str | None:
 | 
				
			||||||
        """Get the path to the MMDB file to load"""
 | 
					        """Get the path to the MMDB file to load"""
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load(self):
 | 
					    def open(self):
 | 
				
			||||||
        """Get GeoIP Reader, if configured, otherwise none"""
 | 
					        """Get GeoIP Reader, if configured, otherwise none"""
 | 
				
			||||||
        path = self.path()
 | 
					        path = self.path()
 | 
				
			||||||
        if path == "" or not path:
 | 
					        if path == "" or not path:
 | 
				
			||||||
@ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor):
 | 
				
			|||||||
            diff = self._last_mtime < mtime
 | 
					            diff = self._last_mtime < mtime
 | 
				
			||||||
            if diff > 0:
 | 
					            if diff > 0:
 | 
				
			||||||
                self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
 | 
					                self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
 | 
				
			||||||
                self.load()
 | 
					                self.open()
 | 
				
			||||||
        except OSError as exc:
 | 
					        except OSError as exc:
 | 
				
			||||||
            self.logger.warning("Failed to check MMDB age", exc=exc)
 | 
					            self.logger.warning("Failed to check MMDB age", exc=exc)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models
 | 
				
			|||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group, User
 | 
				
			||||||
from authentik.events.models import Event, EventAction, Notification
 | 
					from authentik.events.models import Event, EventAction, Notification
 | 
				
			||||||
from authentik.events.utils import model_to_dict
 | 
					from authentik.events.utils import model_to_dict
 | 
				
			||||||
from authentik.lib.sentry import should_ignore_exception
 | 
					from authentik.lib.sentry import before_send
 | 
				
			||||||
from authentik.lib.utils.errors import exception_to_string
 | 
					from authentik.lib.utils.errors import exception_to_string
 | 
				
			||||||
from authentik.stages.authenticator_static.models import StaticToken
 | 
					from authentik.stages.authenticator_static.models import StaticToken
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -173,7 +173,7 @@ class AuditMiddleware:
 | 
				
			|||||||
                message=exception_to_string(exception),
 | 
					                message=exception_to_string(exception),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            thread.run()
 | 
					            thread.run()
 | 
				
			||||||
        elif not should_ignore_exception(exception):
 | 
					        elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
 | 
				
			||||||
            thread = EventNewThread(
 | 
					            thread = EventNewThread(
 | 
				
			||||||
                EventAction.SYSTEM_EXCEPTION,
 | 
					                EventAction.SYSTEM_EXCEPTION,
 | 
				
			||||||
                request,
 | 
					                request,
 | 
				
			||||||
 | 
				
			|||||||
@ -1,26 +0,0 @@
 | 
				
			|||||||
# Generated by Django 5.1.11 on 2025-06-16 23:21
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.db import migrations, models
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Migration(migrations.Migration):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    dependencies = [
 | 
					 | 
				
			||||||
        ("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    operations = [
 | 
					 | 
				
			||||||
        migrations.RenameField(
 | 
					 | 
				
			||||||
            model_name="notificationrule",
 | 
					 | 
				
			||||||
            old_name="group",
 | 
					 | 
				
			||||||
            new_name="destination_group",
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
        migrations.AddField(
 | 
					 | 
				
			||||||
            model_name="notificationrule",
 | 
					 | 
				
			||||||
            name="destination_event_user",
 | 
					 | 
				
			||||||
            field=models.BooleanField(
 | 
					 | 
				
			||||||
                default=False,
 | 
					 | 
				
			||||||
                help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.",
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
@ -1,12 +1,10 @@
 | 
				
			|||||||
"""authentik events models"""
 | 
					"""authentik events models"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from collections.abc import Generator
 | 
					 | 
				
			||||||
from datetime import timedelta
 | 
					from datetime import timedelta
 | 
				
			||||||
from difflib import get_close_matches
 | 
					from difflib import get_close_matches
 | 
				
			||||||
from functools import lru_cache
 | 
					from functools import lru_cache
 | 
				
			||||||
from inspect import currentframe
 | 
					from inspect import currentframe
 | 
				
			||||||
from smtplib import SMTPException
 | 
					from smtplib import SMTPException
 | 
				
			||||||
from typing import Any
 | 
					 | 
				
			||||||
from uuid import uuid4
 | 
					from uuid import uuid4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.apps import apps
 | 
					from django.apps import apps
 | 
				
			||||||
@ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel):
 | 
				
			|||||||
            brand: Brand = request.brand
 | 
					            brand: Brand = request.brand
 | 
				
			||||||
            self.brand = sanitize_dict(model_to_dict(brand))
 | 
					            self.brand = sanitize_dict(model_to_dict(brand))
 | 
				
			||||||
        if hasattr(request, "user"):
 | 
					        if hasattr(request, "user"):
 | 
				
			||||||
            self.user = get_user(request.user)
 | 
					            original_user = None
 | 
				
			||||||
 | 
					            if hasattr(request, "session"):
 | 
				
			||||||
 | 
					                original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None)
 | 
				
			||||||
 | 
					            self.user = get_user(request.user, original_user)
 | 
				
			||||||
        if user:
 | 
					        if user:
 | 
				
			||||||
            self.user = get_user(user)
 | 
					            self.user = get_user(user)
 | 
				
			||||||
 | 
					        # Check if we're currently impersonating, and add that user
 | 
				
			||||||
        if hasattr(request, "session"):
 | 
					        if hasattr(request, "session"):
 | 
				
			||||||
            from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            # Check if we're currently impersonating, and add that user
 | 
					 | 
				
			||||||
            if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
 | 
					            if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
 | 
				
			||||||
                self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
 | 
					                self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
 | 
				
			||||||
                self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
 | 
					                self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
 | 
				
			||||||
            # Special case for events that happen during a flow, the user might not be authenticated
 | 
					 | 
				
			||||||
            # yet but is a pending user instead
 | 
					 | 
				
			||||||
            if SESSION_KEY_PLAN in request.session:
 | 
					 | 
				
			||||||
                from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                plan: FlowPlan = request.session[SESSION_KEY_PLAN]
 | 
					 | 
				
			||||||
                pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None)
 | 
					 | 
				
			||||||
                # Only save `authenticated_as` if there's a different pending user in the flow
 | 
					 | 
				
			||||||
                # than the user that is authenticated
 | 
					 | 
				
			||||||
                if pending_user and (
 | 
					 | 
				
			||||||
                    (pending_user.pk and pending_user.pk != self.user.get("pk"))
 | 
					 | 
				
			||||||
                    or (not pending_user.pk)
 | 
					 | 
				
			||||||
                ):
 | 
					 | 
				
			||||||
                    orig_user = self.user.copy()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                    self.user = {"authenticated_as": orig_user, **get_user(pending_user)}
 | 
					 | 
				
			||||||
        # User 255.255.255.255 as fallback if IP cannot be determined
 | 
					        # User 255.255.255.255 as fallback if IP cannot be determined
 | 
				
			||||||
        self.client_ip = ClientIPMiddleware.get_client_ip(request)
 | 
					        self.client_ip = ClientIPMiddleware.get_client_ip(request)
 | 
				
			||||||
        # Enrich event data
 | 
					        # Enrich event data
 | 
				
			||||||
@ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
 | 
				
			|||||||
        default=NotificationSeverity.NOTICE,
 | 
					        default=NotificationSeverity.NOTICE,
 | 
				
			||||||
        help_text=_("Controls which severity level the created notifications will have."),
 | 
					        help_text=_("Controls which severity level the created notifications will have."),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    destination_group = models.ForeignKey(
 | 
					    group = models.ForeignKey(
 | 
				
			||||||
        Group,
 | 
					        Group,
 | 
				
			||||||
        help_text=_(
 | 
					        help_text=_(
 | 
				
			||||||
            "Define which group of users this notification should be sent and shown to. "
 | 
					            "Define which group of users this notification should be sent and shown to. "
 | 
				
			||||||
@ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
 | 
				
			|||||||
        blank=True,
 | 
					        blank=True,
 | 
				
			||||||
        on_delete=models.SET_NULL,
 | 
					        on_delete=models.SET_NULL,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    destination_event_user = models.BooleanField(
 | 
					 | 
				
			||||||
        default=False,
 | 
					 | 
				
			||||||
        help_text=_(
 | 
					 | 
				
			||||||
            "When enabled, notification will be sent to user the user that triggered the event."
 | 
					 | 
				
			||||||
            "When destination_group is configured, notification is sent to both."
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def destination_users(self, event: Event) -> Generator[User, Any]:
 | 
					 | 
				
			||||||
        if self.destination_event_user and event.user.get("pk"):
 | 
					 | 
				
			||||||
            yield User(pk=event.user.get("pk"))
 | 
					 | 
				
			||||||
        if self.destination_group:
 | 
					 | 
				
			||||||
            yield from self.destination_group.users.all()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def serializer(self) -> type[Serializer]:
 | 
					    def serializer(self) -> type[Serializer]:
 | 
				
			||||||
 | 
				
			|||||||
@ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
 | 
				
			|||||||
    if not result.passing:
 | 
					    if not result.passing:
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not trigger.group:
 | 
				
			||||||
 | 
					        LOGGER.debug("e(trigger): trigger has no group", trigger=trigger)
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
 | 
					    LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
 | 
				
			||||||
    # Create the notification objects
 | 
					    # Create the notification objects
 | 
				
			||||||
    for transport in trigger.transports.all():
 | 
					    for transport in trigger.transports.all():
 | 
				
			||||||
        for user in trigger.destination_users(event):
 | 
					        for user in trigger.group.users.all():
 | 
				
			||||||
            LOGGER.debug("created notification")
 | 
					            LOGGER.debug("created notification")
 | 
				
			||||||
            notification_transport.apply_async(
 | 
					            notification_transport.apply_async(
 | 
				
			||||||
                args=[
 | 
					                args=[
 | 
				
			||||||
 | 
				
			|||||||
@ -2,9 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from django.test import TestCase
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.events.context_processors.base import get_context_processors
 | 
					 | 
				
			||||||
from authentik.events.context_processors.geoip import GeoIPContextProcessor
 | 
					from authentik.events.context_processors.geoip import GeoIPContextProcessor
 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestGeoIP(TestCase):
 | 
					class TestGeoIP(TestCase):
 | 
				
			||||||
@ -15,7 +13,8 @@ class TestGeoIP(TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def test_simple(self):
 | 
					    def test_simple(self):
 | 
				
			||||||
        """Test simple city wrapper"""
 | 
					        """Test simple city wrapper"""
 | 
				
			||||||
        # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
 | 
					        # IPs from
 | 
				
			||||||
 | 
					        # https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            self.reader.city_dict("2.125.160.216"),
 | 
					            self.reader.city_dict("2.125.160.216"),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
@ -26,12 +25,3 @@ class TestGeoIP(TestCase):
 | 
				
			|||||||
                "long": -1.25,
 | 
					                "long": -1.25,
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_special_chars(self):
 | 
					 | 
				
			||||||
        """Test city name with special characters"""
 | 
					 | 
				
			||||||
        # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
 | 
					 | 
				
			||||||
        event = Event.new(EventAction.LOGIN)
 | 
					 | 
				
			||||||
        event.client_ip = "89.160.20.112"
 | 
					 | 
				
			||||||
        for processor in get_context_processors():
 | 
					 | 
				
			||||||
            processor.enrich_event(event)
 | 
					 | 
				
			||||||
        event.save()
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter
 | 
				
			|||||||
from guardian.shortcuts import get_anonymous_user
 | 
					from guardian.shortcuts import get_anonymous_user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.brands.models import Brand
 | 
					from authentik.brands.models import Brand
 | 
				
			||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group
 | 
				
			||||||
from authentik.core.tests.utils import create_test_user
 | 
					 | 
				
			||||||
from authentik.events.models import Event
 | 
					from authentik.events.models import Event
 | 
				
			||||||
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
 | 
					from authentik.flows.views.executor import QS_QUERY
 | 
				
			||||||
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
 | 
					 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.policies.dummy.models import DummyPolicy
 | 
					from authentik.policies.dummy.models import DummyPolicy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -118,92 +116,3 @@ class TestEvents(TestCase):
 | 
				
			|||||||
                "pk": brand.pk.hex,
 | 
					                "pk": brand.pk.hex,
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_from_http_flow_pending_user(self):
 | 
					 | 
				
			||||||
        """Test request from flow request with a pending user"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					 | 
				
			||||||
        plan = FlowPlan(generate_id())
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = user
 | 
					 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request = self.factory.get("/")
 | 
					 | 
				
			||||||
        request.session = session
 | 
					 | 
				
			||||||
        request.user = user
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        event = Event.new("unittest").from_http(request)
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            event.user,
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "email": user.email,
 | 
					 | 
				
			||||||
                "pk": user.pk,
 | 
					 | 
				
			||||||
                "username": user.username,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_from_http_flow_pending_user_anon(self):
 | 
					 | 
				
			||||||
        """Test request from flow request with a pending user"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        anon = get_anonymous_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					 | 
				
			||||||
        plan = FlowPlan(generate_id())
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = user
 | 
					 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request = self.factory.get("/")
 | 
					 | 
				
			||||||
        request.session = session
 | 
					 | 
				
			||||||
        request.user = anon
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        event = Event.new("unittest").from_http(request)
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            event.user,
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "authenticated_as": {
 | 
					 | 
				
			||||||
                    "pk": anon.pk,
 | 
					 | 
				
			||||||
                    "is_anonymous": True,
 | 
					 | 
				
			||||||
                    "username": "AnonymousUser",
 | 
					 | 
				
			||||||
                    "email": "",
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "email": user.email,
 | 
					 | 
				
			||||||
                "pk": user.pk,
 | 
					 | 
				
			||||||
                "username": user.username,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_from_http_flow_pending_user_fake(self):
 | 
					 | 
				
			||||||
        """Test request from flow request with a pending user"""
 | 
					 | 
				
			||||||
        user = User(
 | 
					 | 
				
			||||||
            username=generate_id(),
 | 
					 | 
				
			||||||
            email=generate_id(),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        anon = get_anonymous_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					 | 
				
			||||||
        plan = FlowPlan(generate_id())
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = user
 | 
					 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        request = self.factory.get("/")
 | 
					 | 
				
			||||||
        request.session = session
 | 
					 | 
				
			||||||
        request.user = anon
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        event = Event.new("unittest").from_http(request)
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            event.user,
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "authenticated_as": {
 | 
					 | 
				
			||||||
                    "pk": anon.pk,
 | 
					 | 
				
			||||||
                    "is_anonymous": True,
 | 
					 | 
				
			||||||
                    "username": "AnonymousUser",
 | 
					 | 
				
			||||||
                    "email": "",
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "email": user.email,
 | 
					 | 
				
			||||||
                "pk": user.pk,
 | 
					 | 
				
			||||||
                "username": user.username,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,6 @@ from django.urls import reverse
 | 
				
			|||||||
from rest_framework.test import APITestCase
 | 
					from rest_framework.test import APITestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group, User
 | 
				
			||||||
from authentik.core.tests.utils import create_test_user
 | 
					 | 
				
			||||||
from authentik.events.models import (
 | 
					from authentik.events.models import (
 | 
				
			||||||
    Event,
 | 
					    Event,
 | 
				
			||||||
    EventAction,
 | 
					    EventAction,
 | 
				
			||||||
@ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
    def test_trigger_empty(self):
 | 
					    def test_trigger_empty(self):
 | 
				
			||||||
        """Test trigger without any policies attached"""
 | 
					        """Test trigger without any policies attached"""
 | 
				
			||||||
        transport = NotificationTransport.objects.create(name=generate_id())
 | 
					        transport = NotificationTransport.objects.create(name=generate_id())
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
 | 
					        trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
        trigger.save()
 | 
					        trigger.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
    def test_trigger_single(self):
 | 
					    def test_trigger_single(self):
 | 
				
			||||||
        """Test simple transport triggering"""
 | 
					        """Test simple transport triggering"""
 | 
				
			||||||
        transport = NotificationTransport.objects.create(name=generate_id())
 | 
					        transport = NotificationTransport.objects.create(name=generate_id())
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
 | 
					        trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
        trigger.save()
 | 
					        trigger.save()
 | 
				
			||||||
        matcher = EventMatcherPolicy.objects.create(
 | 
					        matcher = EventMatcherPolicy.objects.create(
 | 
				
			||||||
@ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
            Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
					            Event.new(EventAction.CUSTOM_PREFIX).save()
 | 
				
			||||||
        self.assertEqual(execute_mock.call_count, 1)
 | 
					        self.assertEqual(execute_mock.call_count, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_trigger_event_user(self):
 | 
					 | 
				
			||||||
        """Test trigger with event user"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        transport = NotificationTransport.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True)
 | 
					 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					 | 
				
			||||||
        trigger.save()
 | 
					 | 
				
			||||||
        matcher = EventMatcherPolicy.objects.create(
 | 
					 | 
				
			||||||
            name="matcher", action=EventAction.CUSTOM_PREFIX
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        execute_mock = MagicMock()
 | 
					 | 
				
			||||||
        with patch("authentik.events.models.NotificationTransport.send", execute_mock):
 | 
					 | 
				
			||||||
            Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save()
 | 
					 | 
				
			||||||
        self.assertEqual(execute_mock.call_count, 1)
 | 
					 | 
				
			||||||
        notification: Notification = execute_mock.call_args[0][0]
 | 
					 | 
				
			||||||
        self.assertEqual(notification.user, user)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_trigger_no_group(self):
 | 
					    def test_trigger_no_group(self):
 | 
				
			||||||
        """Test trigger without group"""
 | 
					        """Test trigger without group"""
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id())
 | 
					        trigger = NotificationRule.objects.create(name=generate_id())
 | 
				
			||||||
@ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
        """Test Policy error which would cause recursion"""
 | 
					        """Test Policy error which would cause recursion"""
 | 
				
			||||||
        transport = NotificationTransport.objects.create(name=generate_id())
 | 
					        transport = NotificationTransport.objects.create(name=generate_id())
 | 
				
			||||||
        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
					        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
 | 
					        trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
        trigger.save()
 | 
					        trigger.save()
 | 
				
			||||||
        matcher = EventMatcherPolicy.objects.create(
 | 
					        matcher = EventMatcherPolicy.objects.create(
 | 
				
			||||||
@ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
 | 
					        transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
 | 
				
			||||||
        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
					        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
 | 
					        trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
        trigger.save()
 | 
					        trigger.save()
 | 
				
			||||||
        matcher = EventMatcherPolicy.objects.create(
 | 
					        matcher = EventMatcherPolicy.objects.create(
 | 
				
			||||||
@ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase):
 | 
				
			|||||||
            name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
 | 
					            name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
					        NotificationRule.objects.filter(name__startswith="default").delete()
 | 
				
			||||||
        trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
 | 
					        trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
 | 
				
			||||||
        trigger.transports.add(transport)
 | 
					        trigger.transports.add(transport)
 | 
				
			||||||
        matcher = EventMatcherPolicy.objects.create(
 | 
					        matcher = EventMatcherPolicy.objects.create(
 | 
				
			||||||
            name="matcher", action=EventAction.CUSTOM_PREFIX
 | 
					            name="matcher", action=EventAction.CUSTOM_PREFIX
 | 
				
			||||||
 | 
				
			|||||||
@ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]:
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_user(user: User | AnonymousUser) -> dict[str, Any]:
 | 
					def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]:
 | 
				
			||||||
    """Convert user object to dictionary"""
 | 
					    """Convert user object to dictionary, optionally including the original user"""
 | 
				
			||||||
    if isinstance(user, AnonymousUser):
 | 
					    if isinstance(user, AnonymousUser):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            user = get_anonymous_user()
 | 
					            user = get_anonymous_user()
 | 
				
			||||||
@ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]:
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    if user.username == settings.ANONYMOUS_USER_NAME:
 | 
					    if user.username == settings.ANONYMOUS_USER_NAME:
 | 
				
			||||||
        user_data["is_anonymous"] = True
 | 
					        user_data["is_anonymous"] = True
 | 
				
			||||||
 | 
					    if original_user:
 | 
				
			||||||
 | 
					        original_data = get_user(original_user)
 | 
				
			||||||
 | 
					        original_data["on_behalf_of"] = user_data
 | 
				
			||||||
 | 
					        return original_data
 | 
				
			||||||
    return user_data
 | 
					    return user_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
 | 
				
			|||||||
from urllib.parse import urlencode
 | 
					from urllib.parse import urlencode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.test import override_settings
 | 
					 | 
				
			||||||
from django.test.client import RequestFactory
 | 
					from django.test.client import RequestFactory
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from rest_framework.exceptions import ParseError
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group, User
 | 
				
			||||||
from authentik.core.tests.utils import create_test_flow, create_test_user
 | 
					from authentik.core.tests.utils import create_test_flow, create_test_user
 | 
				
			||||||
@ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase):
 | 
				
			|||||||
            self.assertStageResponse(response, flow, component="ak-stage-identification")
 | 
					            self.assertStageResponse(response, flow, component="ak-stage-identification")
 | 
				
			||||||
            response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
 | 
					            response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
 | 
				
			||||||
            self.assertStageResponse(response, flow, component="ak-stage-access-denied")
 | 
					            self.assertStageResponse(response, flow, component="ak-stage-access-denied")
 | 
				
			||||||
 | 
					 | 
				
			||||||
    @patch(
 | 
					 | 
				
			||||||
        "authentik.flows.views.executor.to_stage_response",
 | 
					 | 
				
			||||||
        TO_STAGE_RESPONSE_MOCK,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    def test_invalid_json(self):
 | 
					 | 
				
			||||||
        """Test invalid JSON body"""
 | 
					 | 
				
			||||||
        flow = create_test_flow()
 | 
					 | 
				
			||||||
        FlowStageBinding.objects.create(
 | 
					 | 
				
			||||||
            target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with override_settings(TEST=False, DEBUG=False):
 | 
					 | 
				
			||||||
            self.client.logout()
 | 
					 | 
				
			||||||
            response = self.client.post(url, data="{", content_type="application/json")
 | 
					 | 
				
			||||||
            self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        with self.assertRaises(ParseError):
 | 
					 | 
				
			||||||
            self.client.logout()
 | 
					 | 
				
			||||||
            response = self.client.post(url, data="{", content_type="application/json")
 | 
					 | 
				
			||||||
            self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -55,7 +55,7 @@ from authentik.flows.planner import (
 | 
				
			|||||||
    FlowPlanner,
 | 
					    FlowPlanner,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.flows.stage import AccessDeniedStage, StageView
 | 
					from authentik.flows.stage import AccessDeniedStage, StageView
 | 
				
			||||||
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
 | 
					from authentik.lib.sentry import SentryIgnoredException
 | 
				
			||||||
from authentik.lib.utils.errors import exception_to_string
 | 
					from authentik.lib.utils.errors import exception_to_string
 | 
				
			||||||
from authentik.lib.utils.reflection import all_subclasses, class_to_path
 | 
					from authentik.lib.utils.reflection import all_subclasses, class_to_path
 | 
				
			||||||
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
 | 
					from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
 | 
				
			||||||
@ -234,13 +234,12 @@ class FlowExecutorView(APIView):
 | 
				
			|||||||
        """Handle exception in stage execution"""
 | 
					        """Handle exception in stage execution"""
 | 
				
			||||||
        if settings.DEBUG or settings.TEST:
 | 
					        if settings.DEBUG or settings.TEST:
 | 
				
			||||||
            raise exc
 | 
					            raise exc
 | 
				
			||||||
 | 
					        capture_exception(exc)
 | 
				
			||||||
        self._logger.warning(exc)
 | 
					        self._logger.warning(exc)
 | 
				
			||||||
        if not should_ignore_exception(exc):
 | 
					        Event.new(
 | 
				
			||||||
            capture_exception(exc)
 | 
					            action=EventAction.SYSTEM_EXCEPTION,
 | 
				
			||||||
            Event.new(
 | 
					            message=exception_to_string(exc),
 | 
				
			||||||
                action=EventAction.SYSTEM_EXCEPTION,
 | 
					        ).from_http(self.request)
 | 
				
			||||||
                message=exception_to_string(exc),
 | 
					 | 
				
			||||||
            ).from_http(self.request)
 | 
					 | 
				
			||||||
        challenge = FlowErrorChallenge(self.request, exc)
 | 
					        challenge = FlowErrorChallenge(self.request, exc)
 | 
				
			||||||
        challenge.is_valid(raise_exception=True)
 | 
					        challenge.is_valid(raise_exception=True)
 | 
				
			||||||
        return to_stage_response(self.request, HttpChallengeResponse(challenge))
 | 
					        return to_stage_response(self.request, HttpChallengeResponse(challenge))
 | 
				
			||||||
 | 
				
			|||||||
@ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted
 | 
				
			|||||||
from docker.errors import DockerException
 | 
					from docker.errors import DockerException
 | 
				
			||||||
from h11 import LocalProtocolError
 | 
					from h11 import LocalProtocolError
 | 
				
			||||||
from ldap3.core.exceptions import LDAPException
 | 
					from ldap3.core.exceptions import LDAPException
 | 
				
			||||||
from psycopg.errors import Error
 | 
					 | 
				
			||||||
from redis.exceptions import ConnectionError as RedisConnectionError
 | 
					from redis.exceptions import ConnectionError as RedisConnectionError
 | 
				
			||||||
from redis.exceptions import RedisError, ResponseError
 | 
					from redis.exceptions import RedisError, ResponseError
 | 
				
			||||||
from rest_framework.exceptions import APIException
 | 
					from rest_framework.exceptions import APIException
 | 
				
			||||||
@ -45,49 +44,6 @@ class SentryIgnoredException(Exception):
 | 
				
			|||||||
    """Base Class for all errors that are suppressed, and not sent to sentry."""
 | 
					    """Base Class for all errors that are suppressed, and not sent to sentry."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ignored_classes = (
 | 
					 | 
				
			||||||
    # Inbuilt types
 | 
					 | 
				
			||||||
    KeyboardInterrupt,
 | 
					 | 
				
			||||||
    ConnectionResetError,
 | 
					 | 
				
			||||||
    OSError,
 | 
					 | 
				
			||||||
    PermissionError,
 | 
					 | 
				
			||||||
    # Django Errors
 | 
					 | 
				
			||||||
    Error,
 | 
					 | 
				
			||||||
    ImproperlyConfigured,
 | 
					 | 
				
			||||||
    DatabaseError,
 | 
					 | 
				
			||||||
    OperationalError,
 | 
					 | 
				
			||||||
    InternalError,
 | 
					 | 
				
			||||||
    ProgrammingError,
 | 
					 | 
				
			||||||
    SuspiciousOperation,
 | 
					 | 
				
			||||||
    ValidationError,
 | 
					 | 
				
			||||||
    # Redis errors
 | 
					 | 
				
			||||||
    RedisConnectionError,
 | 
					 | 
				
			||||||
    ConnectionInterrupted,
 | 
					 | 
				
			||||||
    RedisError,
 | 
					 | 
				
			||||||
    ResponseError,
 | 
					 | 
				
			||||||
    # websocket errors
 | 
					 | 
				
			||||||
    ChannelFull,
 | 
					 | 
				
			||||||
    WebSocketException,
 | 
					 | 
				
			||||||
    LocalProtocolError,
 | 
					 | 
				
			||||||
    # rest_framework error
 | 
					 | 
				
			||||||
    APIException,
 | 
					 | 
				
			||||||
    # celery errors
 | 
					 | 
				
			||||||
    WorkerLostError,
 | 
					 | 
				
			||||||
    CeleryError,
 | 
					 | 
				
			||||||
    SoftTimeLimitExceeded,
 | 
					 | 
				
			||||||
    # custom baseclass
 | 
					 | 
				
			||||||
    SentryIgnoredException,
 | 
					 | 
				
			||||||
    # ldap errors
 | 
					 | 
				
			||||||
    LDAPException,
 | 
					 | 
				
			||||||
    # Docker errors
 | 
					 | 
				
			||||||
    DockerException,
 | 
					 | 
				
			||||||
    # End-user errors
 | 
					 | 
				
			||||||
    Http404,
 | 
					 | 
				
			||||||
    # AsyncIO
 | 
					 | 
				
			||||||
    CancelledError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SentryTransport(HttpTransport):
 | 
					class SentryTransport(HttpTransport):
 | 
				
			||||||
    """Custom sentry transport with custom user-agent"""
 | 
					    """Custom sentry transport with custom user-agent"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float:
 | 
				
			|||||||
    return float(CONFIG.get("error_reporting.sample_rate", 0.1))
 | 
					    return float(CONFIG.get("error_reporting.sample_rate", 0.1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def should_ignore_exception(exc: Exception) -> bool:
 | 
					 | 
				
			||||||
    """Check if an exception should be dropped"""
 | 
					 | 
				
			||||||
    return isinstance(exc, ignored_classes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def before_send(event: dict, hint: dict) -> dict | None:
 | 
					def before_send(event: dict, hint: dict) -> dict | None:
 | 
				
			||||||
    """Check if error is database error, and ignore if so"""
 | 
					    """Check if error is database error, and ignore if so"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    from psycopg.errors import Error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ignored_classes = (
 | 
				
			||||||
 | 
					        # Inbuilt types
 | 
				
			||||||
 | 
					        KeyboardInterrupt,
 | 
				
			||||||
 | 
					        ConnectionResetError,
 | 
				
			||||||
 | 
					        OSError,
 | 
				
			||||||
 | 
					        PermissionError,
 | 
				
			||||||
 | 
					        # Django Errors
 | 
				
			||||||
 | 
					        Error,
 | 
				
			||||||
 | 
					        ImproperlyConfigured,
 | 
				
			||||||
 | 
					        DatabaseError,
 | 
				
			||||||
 | 
					        OperationalError,
 | 
				
			||||||
 | 
					        InternalError,
 | 
				
			||||||
 | 
					        ProgrammingError,
 | 
				
			||||||
 | 
					        SuspiciousOperation,
 | 
				
			||||||
 | 
					        ValidationError,
 | 
				
			||||||
 | 
					        # Redis errors
 | 
				
			||||||
 | 
					        RedisConnectionError,
 | 
				
			||||||
 | 
					        ConnectionInterrupted,
 | 
				
			||||||
 | 
					        RedisError,
 | 
				
			||||||
 | 
					        ResponseError,
 | 
				
			||||||
 | 
					        # websocket errors
 | 
				
			||||||
 | 
					        ChannelFull,
 | 
				
			||||||
 | 
					        WebSocketException,
 | 
				
			||||||
 | 
					        LocalProtocolError,
 | 
				
			||||||
 | 
					        # rest_framework error
 | 
				
			||||||
 | 
					        APIException,
 | 
				
			||||||
 | 
					        # celery errors
 | 
				
			||||||
 | 
					        WorkerLostError,
 | 
				
			||||||
 | 
					        CeleryError,
 | 
				
			||||||
 | 
					        SoftTimeLimitExceeded,
 | 
				
			||||||
 | 
					        # custom baseclass
 | 
				
			||||||
 | 
					        SentryIgnoredException,
 | 
				
			||||||
 | 
					        # ldap errors
 | 
				
			||||||
 | 
					        LDAPException,
 | 
				
			||||||
 | 
					        # Docker errors
 | 
				
			||||||
 | 
					        DockerException,
 | 
				
			||||||
 | 
					        # End-user errors
 | 
				
			||||||
 | 
					        Http404,
 | 
				
			||||||
 | 
					        # AsyncIO
 | 
				
			||||||
 | 
					        CancelledError,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    exc_value = None
 | 
					    exc_value = None
 | 
				
			||||||
    if "exc_info" in hint:
 | 
					    if "exc_info" in hint:
 | 
				
			||||||
        _, exc_value, _ = hint["exc_info"]
 | 
					        _, exc_value, _ = hint["exc_info"]
 | 
				
			||||||
        if should_ignore_exception(exc_value):
 | 
					        if isinstance(exc_value, ignored_classes):
 | 
				
			||||||
            LOGGER.debug("dropping exception", exc=exc_value)
 | 
					            LOGGER.debug("dropping exception", exc=exc_value)
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
    if "logger" in event:
 | 
					    if "logger" in event:
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from django.test import TestCase
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
 | 
					from authentik.lib.sentry import SentryIgnoredException, before_send
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestSentry(TestCase):
 | 
					class TestSentry(TestCase):
 | 
				
			||||||
@ -10,8 +10,8 @@ class TestSentry(TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def test_error_not_sent(self):
 | 
					    def test_error_not_sent(self):
 | 
				
			||||||
        """Test SentryIgnoredError not sent"""
 | 
					        """Test SentryIgnoredError not sent"""
 | 
				
			||||||
        self.assertTrue(should_ignore_exception(SentryIgnoredException()))
 | 
					        self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_error_sent(self):
 | 
					    def test_error_sent(self):
 | 
				
			||||||
        """Test error sent"""
 | 
					        """Test error sent"""
 | 
				
			||||||
        self.assertFalse(should_ignore_exception(ValueError()))
 | 
					        self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)}))
 | 
				
			||||||
 | 
				
			|||||||
@ -1,13 +1,15 @@
 | 
				
			|||||||
"""authentik outpost signals"""
 | 
					"""authentik outpost signals"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.contrib.auth.signals import user_logged_out
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.db.models import Model
 | 
					from django.db.models import Model
 | 
				
			||||||
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
 | 
					from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
 | 
					from django.http import HttpRequest
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.brands.models import Brand
 | 
					from authentik.brands.models import Brand
 | 
				
			||||||
from authentik.core.models import AuthenticatedSession, Provider
 | 
					from authentik.core.models import AuthenticatedSession, Provider, User
 | 
				
			||||||
from authentik.crypto.models import CertificateKeyPair
 | 
					from authentik.crypto.models import CertificateKeyPair
 | 
				
			||||||
from authentik.lib.utils.reflection import class_to_path
 | 
					from authentik.lib.utils.reflection import class_to_path
 | 
				
			||||||
from authentik.outposts.models import Outpost, OutpostServiceConnection
 | 
					from authentik.outposts.models import Outpost, OutpostServiceConnection
 | 
				
			||||||
@ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
 | 
				
			|||||||
    outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
 | 
					    outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@receiver(user_logged_out)
 | 
				
			||||||
 | 
					def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
 | 
				
			||||||
 | 
					    """Catch logout by direct logout and forward to providers"""
 | 
				
			||||||
 | 
					    if not request.session or not request.session.session_key:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    outpost_session_end.delay(request.session.session_key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
					@receiver(pre_delete, sender=AuthenticatedSession)
 | 
				
			||||||
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
 | 
					def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
 | 
				
			||||||
    """Catch logout by expiring sessions being deleted"""
 | 
					    """Catch logout by expiring sessions being deleted"""
 | 
				
			||||||
 | 
				
			|||||||
@ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    error: str
 | 
					    error: str
 | 
				
			||||||
    description: str
 | 
					    description: str
 | 
				
			||||||
 | 
					    cause: str | None = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_dict(self):
 | 
					    def create_dict(self):
 | 
				
			||||||
        """Return error as dict for JSON Rendering"""
 | 
					        """Return error as dict for JSON Rendering"""
 | 
				
			||||||
@ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException):
 | 
				
			|||||||
            **kwargs,
 | 
					            **kwargs,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def with_cause(self, cause: str):
 | 
				
			||||||
 | 
					        self.cause = cause
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RedirectUriError(OAuth2Error):
 | 
					class RedirectUriError(OAuth2Error):
 | 
				
			||||||
    """The request fails due to a missing, invalid, or mismatching
 | 
					    """The request fails due to a missing, invalid, or mismatching
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,23 @@
 | 
				
			|||||||
 | 
					from django.contrib.auth.signals import user_logged_out
 | 
				
			||||||
from django.db.models.signals import post_save, pre_delete
 | 
					from django.db.models.signals import post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
 | 
					from django.http import HttpRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import AuthenticatedSession, User
 | 
					from authentik.core.models import AuthenticatedSession, User
 | 
				
			||||||
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
 | 
					from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@receiver(user_logged_out)
 | 
				
			||||||
 | 
					def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
 | 
				
			||||||
 | 
					    """Revoke tokens upon user logout"""
 | 
				
			||||||
 | 
					    if not request.session or not request.session.session_key:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    AccessToken.objects.filter(
 | 
				
			||||||
 | 
					        user=user,
 | 
				
			||||||
 | 
					        session__session__session_key=request.session.session_key,
 | 
				
			||||||
 | 
					    ).delete()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
					@receiver(pre_delete, sender=AuthenticatedSession)
 | 
				
			||||||
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
 | 
					def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
 | 
				
			||||||
    """Revoke tokens upon user logout"""
 | 
					    """Revoke tokens upon user logout"""
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
 | 
				
			|||||||
from authentik.events.models import Event, EventAction
 | 
					from authentik.events.models import Event, EventAction
 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.lib.utils.time import timedelta_from_string
 | 
					from authentik.lib.utils.time import timedelta_from_string
 | 
				
			||||||
from authentik.providers.oauth2.constants import TOKEN_TYPE
 | 
					from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
 | 
				
			||||||
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
 | 
					from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
 | 
				
			||||||
from authentik.providers.oauth2.models import (
 | 
					from authentik.providers.oauth2.models import (
 | 
				
			||||||
    AccessToken,
 | 
					    AccessToken,
 | 
				
			||||||
@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.error, "unsupported_response_type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_client_id(self):
 | 
					    def test_invalid_client_id(self):
 | 
				
			||||||
        """Test invalid client ID"""
 | 
					        """Test invalid client ID"""
 | 
				
			||||||
@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.error, "request_not_supported")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_redirect_uri(self):
 | 
					    def test_invalid_redirect_uri_missing(self):
 | 
				
			||||||
        """test missing/invalid redirect URI"""
 | 
					        """test missing redirect URI"""
 | 
				
			||||||
        OAuth2Provider.objects.create(
 | 
					        OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError) as cm:
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        self.assertEqual(cm.exception.cause, "redirect_uri_missing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_invalid_redirect_uri(self):
 | 
				
			||||||
 | 
					        """test invalid redirect URI"""
 | 
				
			||||||
 | 
					        OAuth2Provider.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            client_id="test",
 | 
				
			||||||
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with self.assertRaises(RedirectUriError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_blocked_redirect_uri(self):
 | 
					    def test_blocked_redirect_uri(self):
 | 
				
			||||||
        """test missing/invalid redirect URI"""
 | 
					        """test missing/invalid redirect URI"""
 | 
				
			||||||
@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_invalid_redirect_uri_empty(self):
 | 
					    def test_invalid_redirect_uri_empty(self):
 | 
				
			||||||
        """test missing/invalid redirect URI"""
 | 
					        """test missing/invalid redirect URI"""
 | 
				
			||||||
@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[],
 | 
					            redirect_uris=[],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					 | 
				
			||||||
        request = self.factory.get(
 | 
					        request = self.factory.get(
 | 
				
			||||||
            "/",
 | 
					            "/",
 | 
				
			||||||
            data={
 | 
					            data={
 | 
				
			||||||
@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError) as cm:
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_redirect_uri_invalid_regex(self):
 | 
					    def test_redirect_uri_invalid_regex(self):
 | 
				
			||||||
        """test missing/invalid redirect URI (invalid regex)"""
 | 
					        """test missing/invalid redirect URI (invalid regex)"""
 | 
				
			||||||
@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")],
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					        with self.assertRaises(RedirectUriError) as cm:
 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_empty_redirect_uri(self):
 | 
					    def test_redirect_uri_regex(self):
 | 
				
			||||||
        """test empty redirect URI (configure in provider)"""
 | 
					        """test valid redirect URI (regex)"""
 | 
				
			||||||
        OAuth2Provider.objects.create(
 | 
					        OAuth2Provider.objects.create(
 | 
				
			||||||
            name=generate_id(),
 | 
					            name=generate_id(),
 | 
				
			||||||
            client_id="test",
 | 
					            client_id="test",
 | 
				
			||||||
            authorization_flow=create_test_flow(),
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(RedirectUriError):
 | 
					 | 
				
			||||||
            request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
 | 
					 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					 | 
				
			||||||
        request = self.factory.get(
 | 
					        request = self.factory.get(
 | 
				
			||||||
            "/",
 | 
					            "/",
 | 
				
			||||||
            data={
 | 
					            data={
 | 
				
			||||||
                "response_type": "code",
 | 
					                "response_type": "code",
 | 
				
			||||||
                "client_id": "test",
 | 
					                "client_id": "test",
 | 
				
			||||||
                "redirect_uri": "http://localhost",
 | 
					                "redirect_uri": "http://foo.bar.baz",
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        OAuthAuthorizationParams.from_request(request)
 | 
					        OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            GrantTypes.IMPLICIT,
 | 
					            GrantTypes.IMPLICIT,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        # Implicit without openid scope
 | 
					        # Implicit without openid scope
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
 | 
					            OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        with self.assertRaises(AuthorizeError):
 | 
					        with self.assertRaises(AuthorizeError) as cm:
 | 
				
			||||||
            request = self.factory.get(
 | 
					            request = self.factory.get(
 | 
				
			||||||
                "/",
 | 
					                "/",
 | 
				
			||||||
                data={
 | 
					                data={
 | 
				
			||||||
@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            OAuthAuthorizationParams.from_request(request)
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.error, "unsupported_response_type")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_full_code(self):
 | 
					    def test_full_code(self):
 | 
				
			||||||
        """Test full authorization"""
 | 
					        """Test full authorization"""
 | 
				
			||||||
@ -387,7 +393,8 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
            self.assertEqual(
 | 
					            self.assertEqual(
 | 
				
			||||||
                response.url,
 | 
					                response.url,
 | 
				
			||||||
                (
 | 
					                (
 | 
				
			||||||
                    f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}"
 | 
					                    f"http://localhost#access_token={token.token}"
 | 
				
			||||||
 | 
					                    f"&id_token={provider.encode(token.id_token.to_dict())}"
 | 
				
			||||||
                    f"&token_type={TOKEN_TYPE}"
 | 
					                    f"&token_type={TOKEN_TYPE}"
 | 
				
			||||||
                    f"&expires_in={int(expires)}&state={state}"
 | 
					                    f"&expires_in={int(expires)}&state={state}"
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
@ -562,6 +569,7 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                "url": "http://localhost",
 | 
					                "url": "http://localhost",
 | 
				
			||||||
                "title": f"Redirecting to {app.name}...",
 | 
					                "title": f"Redirecting to {app.name}...",
 | 
				
			||||||
                "attrs": {
 | 
					                "attrs": {
 | 
				
			||||||
 | 
					                    "access_token": token.token,
 | 
				
			||||||
                    "id_token": provider.encode(token.id_token.to_dict()),
 | 
					                    "id_token": provider.encode(token.id_token.to_dict()),
 | 
				
			||||||
                    "token_type": TOKEN_TYPE,
 | 
					                    "token_type": TOKEN_TYPE,
 | 
				
			||||||
                    "expires_in": "3600",
 | 
					                    "expires_in": "3600",
 | 
				
			||||||
@ -613,3 +621,54 @@ class TestAuthorize(OAuthTestCase):
 | 
				
			|||||||
                },
 | 
					                },
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_openid_missing_invalid(self):
 | 
				
			||||||
 | 
					        """test request requiring an OpenID scope to be set"""
 | 
				
			||||||
 | 
					        OAuth2Provider.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            client_id="test",
 | 
				
			||||||
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        request = self.factory.get(
 | 
				
			||||||
 | 
					            "/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "response_type": "id_token",
 | 
				
			||||||
 | 
					                "client_id": "test",
 | 
				
			||||||
 | 
					                "redirect_uri": "http://localhost",
 | 
				
			||||||
 | 
					                "scope": "",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with self.assertRaises(AuthorizeError) as cm:
 | 
				
			||||||
 | 
					            OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertEqual(cm.exception.cause, "scope_openid_missing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @apply_blueprint("system/providers-oauth2.yaml")
 | 
				
			||||||
 | 
					    def test_offline_access_invalid(self):
 | 
				
			||||||
 | 
					        """test request for offline_access with invalid response type"""
 | 
				
			||||||
 | 
					        provider = OAuth2Provider.objects.create(
 | 
				
			||||||
 | 
					            name=generate_id(),
 | 
				
			||||||
 | 
					            client_id="test",
 | 
				
			||||||
 | 
					            authorization_flow=create_test_flow(),
 | 
				
			||||||
 | 
					            redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        provider.property_mappings.set(
 | 
				
			||||||
 | 
					            ScopeMapping.objects.filter(
 | 
				
			||||||
 | 
					                managed__in=[
 | 
				
			||||||
 | 
					                    "goauthentik.io/providers/oauth2/scope-openid",
 | 
				
			||||||
 | 
					                    "goauthentik.io/providers/oauth2/scope-offline_access",
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        request = self.factory.get(
 | 
				
			||||||
 | 
					            "/",
 | 
				
			||||||
 | 
					            data={
 | 
				
			||||||
 | 
					                "response_type": "id_token",
 | 
				
			||||||
 | 
					                "client_id": "test",
 | 
				
			||||||
 | 
					                "redirect_uri": "http://localhost",
 | 
				
			||||||
 | 
					                "scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
 | 
				
			||||||
 | 
					                "nonce": generate_id(),
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        parsed = OAuthAuthorizationParams.from_request(request)
 | 
				
			||||||
 | 
					        self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)
 | 
				
			||||||
 | 
				
			|||||||
@ -150,12 +150,12 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
        self.check_redirect_uri()
 | 
					        self.check_redirect_uri()
 | 
				
			||||||
        self.check_grant()
 | 
					        self.check_grant()
 | 
				
			||||||
        self.check_scope(github_compat)
 | 
					        self.check_scope(github_compat)
 | 
				
			||||||
 | 
					        self.check_nonce()
 | 
				
			||||||
 | 
					        self.check_code_challenge()
 | 
				
			||||||
        if self.request:
 | 
					        if self.request:
 | 
				
			||||||
            raise AuthorizeError(
 | 
					            raise AuthorizeError(
 | 
				
			||||||
                self.redirect_uri, "request_not_supported", self.grant_type, self.state
 | 
					                self.redirect_uri, "request_not_supported", self.grant_type, self.state
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        self.check_nonce()
 | 
					 | 
				
			||||||
        self.check_code_challenge()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_grant(self):
 | 
					    def check_grant(self):
 | 
				
			||||||
        """Check grant"""
 | 
					        """Check grant"""
 | 
				
			||||||
@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
        allowed_redirect_urls = self.provider.redirect_uris
 | 
					        allowed_redirect_urls = self.provider.redirect_uris
 | 
				
			||||||
        if not self.redirect_uri:
 | 
					        if not self.redirect_uri:
 | 
				
			||||||
            LOGGER.warning("Missing redirect uri.")
 | 
					            LOGGER.warning("Missing redirect uri.")
 | 
				
			||||||
            raise RedirectUriError("", allowed_redirect_urls)
 | 
					            raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if len(allowed_redirect_urls) < 1:
 | 
					        if len(allowed_redirect_urls) < 1:
 | 
				
			||||||
            LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
 | 
					            LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
 | 
				
			||||||
@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
                        provider=self.provider,
 | 
					                        provider=self.provider,
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
        if not match_found:
 | 
					        if not match_found:
 | 
				
			||||||
            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
					            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
 | 
				
			||||||
 | 
					                "redirect_uri_no_match"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
        # Check against forbidden schemes
 | 
					        # Check against forbidden schemes
 | 
				
			||||||
        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
					        if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
 | 
				
			||||||
            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
 | 
					            raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
 | 
				
			||||||
 | 
					                "redirect_uri_forbidden_scheme"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_scope(self, github_compat=False):
 | 
					    def check_scope(self, github_compat=False):
 | 
				
			||||||
        """Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
 | 
					        """Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
 | 
				
			||||||
@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
            or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
 | 
					            or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
 | 
				
			||||||
        ):
 | 
					        ):
 | 
				
			||||||
            LOGGER.warning("Missing 'openid' scope.")
 | 
					            LOGGER.warning("Missing 'openid' scope.")
 | 
				
			||||||
            raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state)
 | 
					            raise AuthorizeError(
 | 
				
			||||||
 | 
					                self.redirect_uri, "invalid_scope", self.grant_type, self.state
 | 
				
			||||||
 | 
					            ).with_cause("scope_openid_missing")
 | 
				
			||||||
        if SCOPE_OFFLINE_ACCESS in self.scope:
 | 
					        if SCOPE_OFFLINE_ACCESS in self.scope:
 | 
				
			||||||
            # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
 | 
					            # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
 | 
				
			||||||
            # Don't explicitly request consent with offline_access, as the spec allows for
 | 
					            # Don't explicitly request consent with offline_access, as the spec allows for
 | 
				
			||||||
@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
 | 
				
			|||||||
            return
 | 
					            return
 | 
				
			||||||
        if not self.nonce:
 | 
					        if not self.nonce:
 | 
				
			||||||
            LOGGER.warning("Missing nonce for OpenID Request")
 | 
					            LOGGER.warning("Missing nonce for OpenID Request")
 | 
				
			||||||
            raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state)
 | 
					            raise AuthorizeError(
 | 
				
			||||||
 | 
					                self.redirect_uri, "invalid_request", self.grant_type, self.state
 | 
				
			||||||
 | 
					            ).with_cause("none_missing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_code_challenge(self):
 | 
					    def check_code_challenge(self):
 | 
				
			||||||
        """PKCE validation of the transformation method."""
 | 
					        """PKCE validation of the transformation method."""
 | 
				
			||||||
@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
 | 
				
			|||||||
                self.request, github_compat=self.github_compat
 | 
					                self.request, github_compat=self.github_compat
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        except AuthorizeError as error:
 | 
					        except AuthorizeError as error:
 | 
				
			||||||
            LOGGER.warning(error.description, redirect_uri=error.redirect_uri)
 | 
					            LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
 | 
				
			||||||
            raise RequestValidationError(error.get_response(self.request)) from None
 | 
					            raise RequestValidationError(error.get_response(self.request)) from None
 | 
				
			||||||
        except OAuth2Error as error:
 | 
					        except OAuth2Error as error:
 | 
				
			||||||
            LOGGER.warning(error.description)
 | 
					            LOGGER.warning(error.description, cause=error.cause)
 | 
				
			||||||
            raise RequestValidationError(
 | 
					            raise RequestValidationError(
 | 
				
			||||||
                bad_request_message(self.request, error.description, title=error.error)
 | 
					                bad_request_message(self.request, error.description, title=error.error)
 | 
				
			||||||
            ) from None
 | 
					            ) from None
 | 
				
			||||||
@ -630,6 +638,7 @@ class OAuthFulfillmentStage(StageView):
 | 
				
			|||||||
        if self.params.response_type in [
 | 
					        if self.params.response_type in [
 | 
				
			||||||
            ResponseTypes.ID_TOKEN_TOKEN,
 | 
					            ResponseTypes.ID_TOKEN_TOKEN,
 | 
				
			||||||
            ResponseTypes.CODE_ID_TOKEN_TOKEN,
 | 
					            ResponseTypes.CODE_ID_TOKEN_TOKEN,
 | 
				
			||||||
 | 
					            ResponseTypes.ID_TOKEN,
 | 
				
			||||||
            ResponseTypes.CODE_TOKEN,
 | 
					            ResponseTypes.CODE_TOKEN,
 | 
				
			||||||
        ]:
 | 
					        ]:
 | 
				
			||||||
            query_fragment["access_token"] = token.token
 | 
					            query_fragment["access_token"] = token.token
 | 
				
			||||||
 | 
				
			|||||||
@ -2,11 +2,13 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from asgiref.sync import async_to_sync
 | 
					from asgiref.sync import async_to_sync
 | 
				
			||||||
from channels.layers import get_channel_layer
 | 
					from channels.layers import get_channel_layer
 | 
				
			||||||
 | 
					from django.contrib.auth.signals import user_logged_out
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from django.db.models.signals import post_delete, post_save, pre_delete
 | 
					from django.db.models.signals import post_delete, post_save, pre_delete
 | 
				
			||||||
from django.dispatch import receiver
 | 
					from django.dispatch import receiver
 | 
				
			||||||
 | 
					from django.http import HttpRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import AuthenticatedSession
 | 
					from authentik.core.models import AuthenticatedSession, User
 | 
				
			||||||
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
 | 
					from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
 | 
				
			||||||
from authentik.providers.rac.consumer_client import (
 | 
					from authentik.providers.rac.consumer_client import (
 | 
				
			||||||
    RAC_CLIENT_GROUP_SESSION,
 | 
					    RAC_CLIENT_GROUP_SESSION,
 | 
				
			||||||
@ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import (
 | 
				
			|||||||
from authentik.providers.rac.models import ConnectionToken, Endpoint
 | 
					from authentik.providers.rac.models import ConnectionToken, Endpoint
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@receiver(user_logged_out)
 | 
				
			||||||
 | 
					def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
 | 
				
			||||||
 | 
					    """Disconnect any open RAC connections"""
 | 
				
			||||||
 | 
					    if not request.session or not request.session.session_key:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    layer = get_channel_layer()
 | 
				
			||||||
 | 
					    async_to_sync(layer.group_send)(
 | 
				
			||||||
 | 
					        RAC_CLIENT_GROUP_SESSION
 | 
				
			||||||
 | 
					        % {
 | 
				
			||||||
 | 
					            "session": request.session.session_key,
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        {"type": "event.disconnect", "reason": "session_logout"},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@receiver(pre_delete, sender=AuthenticatedSession)
 | 
					@receiver(pre_delete, sender=AuthenticatedSession)
 | 
				
			||||||
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
 | 
					def user_session_deleted(sender, instance: AuthenticatedSession, **_):
 | 
				
			||||||
    layer = get_channel_layer()
 | 
					    layer = get_channel_layer()
 | 
				
			||||||
 | 
				
			|||||||
@ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            response.content.decode(),
 | 
					            response.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
@ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            response.content.decode(),
 | 
					            response.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
 | 
				
			|||||||
@ -5,6 +5,7 @@ from itertools import batched
 | 
				
			|||||||
from django.db import transaction
 | 
					from django.db import transaction
 | 
				
			||||||
from pydantic import ValidationError
 | 
					from pydantic import ValidationError
 | 
				
			||||||
from pydanticscim.group import GroupMember
 | 
					from pydanticscim.group import GroupMember
 | 
				
			||||||
 | 
					from pydanticscim.responses import PatchOp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group
 | 
					from authentik.core.models import Group
 | 
				
			||||||
from authentik.lib.sync.mapper import PropertyMappingManager
 | 
					from authentik.lib.sync.mapper import PropertyMappingManager
 | 
				
			||||||
@ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
 | 
				
			|||||||
from authentik.providers.scim.clients.exceptions import (
 | 
					from authentik.providers.scim.clients.exceptions import (
 | 
				
			||||||
    SCIMRequestException,
 | 
					    SCIMRequestException,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.providers.scim.clients.schema import (
 | 
					from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
 | 
				
			||||||
    SCIM_GROUP_SCHEMA,
 | 
					 | 
				
			||||||
    PatchOp,
 | 
					 | 
				
			||||||
    PatchOperation,
 | 
					 | 
				
			||||||
    PatchRequest,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
					from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
				
			||||||
from authentik.providers.scim.models import (
 | 
					from authentik.providers.scim.models import (
 | 
				
			||||||
    SCIMMapping,
 | 
					    SCIMMapping,
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,5 @@
 | 
				
			|||||||
"""Custom SCIM schemas"""
 | 
					"""Custom SCIM schemas"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from enum import Enum
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from pydantic import Field
 | 
					from pydantic import Field
 | 
				
			||||||
from pydanticscim.group import Group as BaseGroup
 | 
					from pydanticscim.group import Group as BaseGroup
 | 
				
			||||||
from pydanticscim.responses import PatchOperation as BasePatchOperation
 | 
					from pydanticscim.responses import PatchOperation as BasePatchOperation
 | 
				
			||||||
@ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PatchOp(str, Enum):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    replace = "replace"
 | 
					 | 
				
			||||||
    remove = "remove"
 | 
					 | 
				
			||||||
    add = "add"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @classmethod
 | 
					 | 
				
			||||||
    def _missing_(cls, value):
 | 
					 | 
				
			||||||
        value = value.lower()
 | 
					 | 
				
			||||||
        for member in cls:
 | 
					 | 
				
			||||||
            if member.lower() == value:
 | 
					 | 
				
			||||||
                return member
 | 
					 | 
				
			||||||
        return None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class PatchRequest(BasePatchRequest):
 | 
					class PatchRequest(BasePatchRequest):
 | 
				
			||||||
    """PatchRequest which correctly sets schemas"""
 | 
					    """PatchRequest which correctly sets schemas"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest):
 | 
				
			|||||||
class PatchOperation(BasePatchOperation):
 | 
					class PatchOperation(BasePatchOperation):
 | 
				
			||||||
    """PatchOperation with optional path"""
 | 
					    """PatchOperation with optional path"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    op: PatchOp
 | 
					 | 
				
			||||||
    path: str | None
 | 
					    path: str | None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            res.content.decode(),
 | 
					            res.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
 | 
				
			|||||||
@ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            res.content.decode(),
 | 
					            res.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
 | 
				
			|||||||
@ -38,7 +38,6 @@ class TestAPIPerms(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            res.content.decode(),
 | 
					            res.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
@ -74,7 +73,6 @@ class TestAPIPerms(APITestCase):
 | 
				
			|||||||
        self.assertJSONEqual(
 | 
					        self.assertJSONEqual(
 | 
				
			||||||
            res.content.decode(),
 | 
					            res.content.decode(),
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "autocomplete": {},
 | 
					 | 
				
			||||||
                "pagination": {
 | 
					                "pagination": {
 | 
				
			||||||
                    "next": 0,
 | 
					                    "next": 0,
 | 
				
			||||||
                    "previous": 0,
 | 
					                    "previous": 0,
 | 
				
			||||||
 | 
				
			|||||||
@ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import django
 | 
					import django
 | 
				
			||||||
from channels.routing import ProtocolTypeRouter, URLRouter
 | 
					from channels.routing import ProtocolTypeRouter, URLRouter
 | 
				
			||||||
 | 
					from defusedxml import defuse_stdlib
 | 
				
			||||||
from django.core.asgi import get_asgi_application
 | 
					from django.core.asgi import get_asgi_application
 | 
				
			||||||
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
 | 
					from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.root.setup import setup
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
 | 
					# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
setup()
 | 
					defuse_stdlib()
 | 
				
			||||||
django.setup()
 | 
					django.setup()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -27,7 +27,7 @@ from structlog.stdlib import get_logger
 | 
				
			|||||||
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
 | 
					from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik import get_full_version
 | 
					from authentik import get_full_version
 | 
				
			||||||
from authentik.lib.sentry import should_ignore_exception
 | 
					from authentik.lib.sentry import before_send
 | 
				
			||||||
from authentik.lib.utils.errors import exception_to_string
 | 
					from authentik.lib.utils.errors import exception_to_string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# set the default Django settings module for the 'celery' program.
 | 
					# set the default Django settings module for the 'celery' program.
 | 
				
			||||||
@ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
 | 
					    LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
 | 
				
			||||||
    CTX_TASK_ID.set(...)
 | 
					    CTX_TASK_ID.set(...)
 | 
				
			||||||
    if not should_ignore_exception(exception):
 | 
					    if before_send({}, {"exc_info": (None, exception, None)}) is not None:
 | 
				
			||||||
        Event.new(
 | 
					        Event.new(
 | 
				
			||||||
            EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
 | 
					            EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
 | 
				
			||||||
        ).save()
 | 
					        ).save()
 | 
				
			||||||
 | 
				
			|||||||
@ -1,49 +1,13 @@
 | 
				
			|||||||
"""authentik database backend"""
 | 
					"""authentik database backend"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.core.checks import Warning
 | 
					 | 
				
			||||||
from django.db.backends.base.validation import BaseDatabaseValidation
 | 
					 | 
				
			||||||
from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
 | 
					from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DatabaseValidation(BaseDatabaseValidation):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def check(self, **kwargs):
 | 
					 | 
				
			||||||
        return self._check_encoding()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _check_encoding(self):
 | 
					 | 
				
			||||||
        """Throw a warning when the server_encoding is not UTF-8 or
 | 
					 | 
				
			||||||
        server_encoding and client_encoding are mismatched"""
 | 
					 | 
				
			||||||
        messages = []
 | 
					 | 
				
			||||||
        with self.connection.cursor() as cursor:
 | 
					 | 
				
			||||||
            cursor.execute("SHOW server_encoding;")
 | 
					 | 
				
			||||||
            server_encoding = cursor.fetchone()[0]
 | 
					 | 
				
			||||||
            cursor.execute("SHOW client_encoding;")
 | 
					 | 
				
			||||||
            client_encoding = cursor.fetchone()[0]
 | 
					 | 
				
			||||||
            if server_encoding != client_encoding:
 | 
					 | 
				
			||||||
                messages.append(
 | 
					 | 
				
			||||||
                    Warning(
 | 
					 | 
				
			||||||
                        "PostgreSQL Server and Client encoding are mismatched: Server: "
 | 
					 | 
				
			||||||
                        f"{server_encoding}, Client: {client_encoding}",
 | 
					 | 
				
			||||||
                        id="ak.db.W001",
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            if server_encoding != "UTF8":
 | 
					 | 
				
			||||||
                messages.append(
 | 
					 | 
				
			||||||
                    Warning(
 | 
					 | 
				
			||||||
                        f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
 | 
					 | 
				
			||||||
                        id="ak.db.W002",
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
        return messages
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class DatabaseWrapper(BaseDatabaseWrapper):
 | 
					class DatabaseWrapper(BaseDatabaseWrapper):
 | 
				
			||||||
    """database backend which supports rotating credentials"""
 | 
					    """database backend which supports rotating credentials"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    validation_class = DatabaseValidation
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_connection_params(self):
 | 
					    def get_connection_params(self):
 | 
				
			||||||
        """Refresh DB credentials before getting connection params"""
 | 
					        """Refresh DB credentials before getting connection params"""
 | 
				
			||||||
        conn_params = super().get_connection_params()
 | 
					        conn_params = super().get_connection_params()
 | 
				
			||||||
 | 
				
			|||||||
@ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [
 | 
				
			|||||||
    "MIDDLEWARE",
 | 
					    "MIDDLEWARE",
 | 
				
			||||||
    "AUTHENTICATION_BACKENDS",
 | 
					    "AUTHENTICATION_BACKENDS",
 | 
				
			||||||
    "CELERY",
 | 
					    "CELERY",
 | 
				
			||||||
    "SPECTACULAR_SETTINGS",
 | 
					 | 
				
			||||||
    "REST_FRAMEWORK",
 | 
					 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SILENCED_SYSTEM_CHECKS = [
 | 
					SILENCED_SYSTEM_CHECKS = [
 | 
				
			||||||
@ -470,8 +468,6 @@ def _update_settings(app_path: str):
 | 
				
			|||||||
        TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
 | 
					        TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
 | 
				
			||||||
        MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
 | 
					        MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
 | 
				
			||||||
        AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
 | 
					        AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
 | 
				
			||||||
        SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
 | 
					 | 
				
			||||||
        REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
 | 
					 | 
				
			||||||
        CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
 | 
					        CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
 | 
				
			||||||
        for _attr in dir(settings_module):
 | 
					        for _attr in dir(settings_module):
 | 
				
			||||||
            if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:
 | 
					            if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:
 | 
				
			||||||
 | 
				
			|||||||
@ -1,26 +0,0 @@
 | 
				
			|||||||
import os
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from cryptography.hazmat.backends.openssl.backend import backend
 | 
					 | 
				
			||||||
from defusedxml import defuse_stdlib
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def setup():
 | 
					 | 
				
			||||||
    warnings.filterwarnings("ignore", "SelectableGroups dict interface")
 | 
					 | 
				
			||||||
    warnings.filterwarnings(
 | 
					 | 
				
			||||||
        "ignore",
 | 
					 | 
				
			||||||
        "defusedxml.lxml is no longer supported and will be removed in a future release.",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    warnings.filterwarnings(
 | 
					 | 
				
			||||||
        "ignore",
 | 
					 | 
				
			||||||
        "defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    defuse_stdlib()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if CONFIG.get_bool("compliance.fips.enabled", False):
 | 
					 | 
				
			||||||
        backend._enable_fips()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
 | 
					 | 
				
			||||||
@ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType
 | 
				
			|||||||
from django.test.runner import DiscoverRunner
 | 
					from django.test.runner import DiscoverRunner
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
 | 
					 | 
				
			||||||
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
 | 
					 | 
				
			||||||
from authentik.lib.config import CONFIG
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
from authentik.lib.sentry import sentry_init
 | 
					from authentik.lib.sentry import sentry_init
 | 
				
			||||||
from authentik.root.signals import post_startup, pre_startup, startup
 | 
					from authentik.root.signals import post_startup, pre_startup, startup
 | 
				
			||||||
@ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner):  # pragma: no cover
 | 
				
			|||||||
        for key, value in test_config.items():
 | 
					        for key, value in test_config.items():
 | 
				
			||||||
            CONFIG.set(key, value)
 | 
					            CONFIG.set(key, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ASN_CONTEXT_PROCESSOR.load()
 | 
					 | 
				
			||||||
        GEOIP_CONTEXT_PROCESSOR.load()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        sentry_init()
 | 
					        sentry_init()
 | 
				
			||||||
        self.logger.debug("Test environment configured")
 | 
					        self.logger.debug("Test environment configured")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str):
 | 
				
			|||||||
            return
 | 
					            return
 | 
				
			||||||
        # Delete all sync tasks from the cache
 | 
					        # Delete all sync tasks from the cache
 | 
				
			||||||
        DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
 | 
					        DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
 | 
				
			||||||
 | 
					        task = chain(
 | 
				
			||||||
        # The order of these operations needs to be preserved as each depends on the previous one(s)
 | 
					            # User and group sync can happen at once, they have no dependencies on each other
 | 
				
			||||||
        # 1. User and group sync can happen simultaneously
 | 
					            group(
 | 
				
			||||||
        # 2. Membership sync needs to run afterwards
 | 
					                ldap_sync_paginator(source, UserLDAPSynchronizer)
 | 
				
			||||||
        # 3. Finally, user and group deletions can happen simultaneously
 | 
					                + ldap_sync_paginator(source, GroupLDAPSynchronizer),
 | 
				
			||||||
        user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator(
 | 
					            ),
 | 
				
			||||||
            source, GroupLDAPSynchronizer
 | 
					            # Membership sync needs to run afterwards
 | 
				
			||||||
 | 
					            group(
 | 
				
			||||||
 | 
					                ldap_sync_paginator(source, MembershipLDAPSynchronizer),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            # Finally, deletions. What we'd really like to do here is something like
 | 
				
			||||||
 | 
					            # ```
 | 
				
			||||||
 | 
					            # user_identifiers = <ldap query>
 | 
				
			||||||
 | 
					            # User.objects.exclude(
 | 
				
			||||||
 | 
					            #     usersourceconnection__identifier__in=user_uniqueness_identifiers,
 | 
				
			||||||
 | 
					            # ).delete()
 | 
				
			||||||
 | 
					            # ```
 | 
				
			||||||
 | 
					            # This runs into performance issues in large installations. So instead we spread the
 | 
				
			||||||
 | 
					            # work out into three steps:
 | 
				
			||||||
 | 
					            # 1. Get every object from the LDAP source.
 | 
				
			||||||
 | 
					            # 2. Mark every object as "safe" in the database. This is quick, but any error could
 | 
				
			||||||
 | 
					            #    mean deleting users which should not be deleted, so we do it immediately, in
 | 
				
			||||||
 | 
					            #    large chunks, and only queue the deletion step afterwards.
 | 
				
			||||||
 | 
					            # 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
 | 
				
			||||||
 | 
					            #    small chunks.
 | 
				
			||||||
 | 
					            group(
 | 
				
			||||||
 | 
					                ldap_sync_paginator(source, UserLDAPForwardDeletion)
 | 
				
			||||||
 | 
					                + ldap_sync_paginator(source, GroupLDAPForwardDeletion),
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer)
 | 
					        task()
 | 
				
			||||||
        user_group_deletion = ldap_sync_paginator(
 | 
					 | 
				
			||||||
            source, UserLDAPForwardDeletion
 | 
					 | 
				
			||||||
        ) + ldap_sync_paginator(source, GroupLDAPForwardDeletion)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Celery is buggy with empty groups, so we are careful only to add non-empty groups.
 | 
					 | 
				
			||||||
        # See https://github.com/celery/celery/issues/9772
 | 
					 | 
				
			||||||
        task_groups = []
 | 
					 | 
				
			||||||
        if user_group_sync:
 | 
					 | 
				
			||||||
            task_groups.append(group(user_group_sync))
 | 
					 | 
				
			||||||
        if membership_sync:
 | 
					 | 
				
			||||||
            task_groups.append(group(membership_sync))
 | 
					 | 
				
			||||||
        if user_group_deletion:
 | 
					 | 
				
			||||||
            task_groups.append(group(user_group_deletion))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        all_tasks = chain(task_groups)
 | 
					 | 
				
			||||||
        all_tasks()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:
 | 
					def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:
 | 
				
			||||||
 | 
				
			|||||||
@ -1,277 +0,0 @@
 | 
				
			|||||||
"""Test SCIM Group"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from json import dumps
 | 
					 | 
				
			||||||
from uuid import uuid4
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.urls import reverse
 | 
					 | 
				
			||||||
from rest_framework.test import APITestCase
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from authentik.core.models import Group
 | 
					 | 
				
			||||||
from authentik.core.tests.utils import create_test_user
 | 
					 | 
				
			||||||
from authentik.events.models import Event, EventAction
 | 
					 | 
				
			||||||
from authentik.lib.generators import generate_id
 | 
					 | 
				
			||||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
 | 
					 | 
				
			||||||
from authentik.sources.scim.models import (
 | 
					 | 
				
			||||||
    SCIMSource,
 | 
					 | 
				
			||||||
    SCIMSourceGroup,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class TestSCIMGroups(APITestCase):
 | 
					 | 
				
			||||||
    """Test SCIM Group view"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self) -> None:
 | 
					 | 
				
			||||||
        self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_list(self):
 | 
					 | 
				
			||||||
        """Test full group list"""
 | 
					 | 
				
			||||||
        response = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_list_single(self):
 | 
					 | 
				
			||||||
        """Test full group list (single group)"""
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        group.users.add(user)
 | 
					 | 
				
			||||||
        SCIMSourceGroup.objects.create(
 | 
					 | 
				
			||||||
            source=self.source,
 | 
					 | 
				
			||||||
            group=group,
 | 
					 | 
				
			||||||
            id=str(uuid4()),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        response = self.client.get(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                    "group_id": str(group.pk),
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=200)
 | 
					 | 
				
			||||||
        SCIMGroupSchema.model_validate_json(response.content, strict=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_create(self):
 | 
					 | 
				
			||||||
        """Test group create"""
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps({"displayName": generate_id(), "externalId": ext_id}),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 201)
 | 
					 | 
				
			||||||
        self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            Event.objects.filter(
 | 
					 | 
				
			||||||
                action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
 | 
					 | 
				
			||||||
            ).exists()
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_create_members(self):
 | 
					 | 
				
			||||||
        """Test group create"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {
 | 
					 | 
				
			||||||
                    "displayName": generate_id(),
 | 
					 | 
				
			||||||
                    "externalId": ext_id,
 | 
					 | 
				
			||||||
                    "members": [{"value": str(user.uuid)}],
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 201)
 | 
					 | 
				
			||||||
        self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            Event.objects.filter(
 | 
					 | 
				
			||||||
                action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
 | 
					 | 
				
			||||||
            ).exists()
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_create_members_empty(self):
 | 
					 | 
				
			||||||
        """Test group create"""
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 201)
 | 
					 | 
				
			||||||
        self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            Event.objects.filter(
 | 
					 | 
				
			||||||
                action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
 | 
					 | 
				
			||||||
            ).exists()
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_create_duplicate(self):
 | 
					 | 
				
			||||||
        """Test group create (duplicate)"""
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)}
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 409)
 | 
					 | 
				
			||||||
        self.assertJSONEqual(
 | 
					 | 
				
			||||||
            response.content,
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "detail": "Group with ID exists already.",
 | 
					 | 
				
			||||||
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
 | 
					 | 
				
			||||||
                "scimType": "uniqueness",
 | 
					 | 
				
			||||||
                "status": 409,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_update(self):
 | 
					 | 
				
			||||||
        """Test group update"""
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.put(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={"source_slug": self.source.slug, "group_id": group.pk},
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)}
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_update_non_existent(self):
 | 
					 | 
				
			||||||
        """Test group update"""
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.put(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                    "group_id": str(uuid4()),
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=404)
 | 
					 | 
				
			||||||
        self.assertJSONEqual(
 | 
					 | 
				
			||||||
            response.content,
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                "detail": "Group not found.",
 | 
					 | 
				
			||||||
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
 | 
					 | 
				
			||||||
                "status": 404,
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_patch_add(self):
 | 
					 | 
				
			||||||
        """Test group patch"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
 | 
					 | 
				
			||||||
        response = self.client.patch(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={"source_slug": self.source.slug, "group_id": group.pk},
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {
 | 
					 | 
				
			||||||
                    "Operations": [
 | 
					 | 
				
			||||||
                        {
 | 
					 | 
				
			||||||
                            "op": "Add",
 | 
					 | 
				
			||||||
                            "path": "members",
 | 
					 | 
				
			||||||
                            "value": {"value": str(user.uuid)},
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    ]
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=200)
 | 
					 | 
				
			||||||
        self.assertTrue(group.users.filter(pk=user.pk).exists())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_patch_remove(self):
 | 
					 | 
				
			||||||
        """Test group patch"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        group.users.add(user)
 | 
					 | 
				
			||||||
        SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
 | 
					 | 
				
			||||||
        response = self.client.patch(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={"source_slug": self.source.slug, "group_id": group.pk},
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {
 | 
					 | 
				
			||||||
                    "Operations": [
 | 
					 | 
				
			||||||
                        {
 | 
					 | 
				
			||||||
                            "op": "remove",
 | 
					 | 
				
			||||||
                            "path": "members",
 | 
					 | 
				
			||||||
                            "value": {"value": str(user.uuid)},
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    ]
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=200)
 | 
					 | 
				
			||||||
        self.assertFalse(group.users.filter(pk=user.pk).exists())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_group_delete(self):
 | 
					 | 
				
			||||||
        """Test group delete"""
 | 
					 | 
				
			||||||
        group = Group.objects.create(name=generate_id())
 | 
					 | 
				
			||||||
        SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
 | 
					 | 
				
			||||||
        response = self.client.delete(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-groups",
 | 
					 | 
				
			||||||
                kwargs={"source_slug": self.source.slug, "group_id": group.pk},
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, second=204)
 | 
					 | 
				
			||||||
@ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase):
 | 
				
			|||||||
            SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
 | 
					            SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
 | 
				
			||||||
            "0123456789",
 | 
					            "0123456789",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_user_update(self):
 | 
					 | 
				
			||||||
        """Test user update"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
 | 
					 | 
				
			||||||
        ext_id = generate_id()
 | 
					 | 
				
			||||||
        response = self.client.put(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-users",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                    "user_id": str(user.uuid),
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            data=dumps(
 | 
					 | 
				
			||||||
                {
 | 
					 | 
				
			||||||
                    "id": str(existing.pk),
 | 
					 | 
				
			||||||
                    "userName": generate_id(),
 | 
					 | 
				
			||||||
                    "externalId": ext_id,
 | 
					 | 
				
			||||||
                    "emails": [
 | 
					 | 
				
			||||||
                        {
 | 
					 | 
				
			||||||
                            "primary": True,
 | 
					 | 
				
			||||||
                            "value": user.email,
 | 
					 | 
				
			||||||
                        }
 | 
					 | 
				
			||||||
                    ],
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_user_delete(self):
 | 
					 | 
				
			||||||
        """Test user delete"""
 | 
					 | 
				
			||||||
        user = create_test_user()
 | 
					 | 
				
			||||||
        SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
 | 
					 | 
				
			||||||
        response = self.client.delete(
 | 
					 | 
				
			||||||
            reverse(
 | 
					 | 
				
			||||||
                "authentik_sources_scim:v2-users",
 | 
					 | 
				
			||||||
                kwargs={
 | 
					 | 
				
			||||||
                    "source_slug": self.source.slug,
 | 
					 | 
				
			||||||
                    "user_id": str(user.uuid),
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            content_type=SCIM_CONTENT_TYPE,
 | 
					 | 
				
			||||||
            HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 204)
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_
 | 
				
			|||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.views import APIView
 | 
					from rest_framework.views import APIView
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.middleware import CTX_AUTH_VIA
 | 
					 | 
				
			||||||
from authentik.core.models import Token, TokenIntents, User
 | 
					from authentik.core.models import Token, TokenIntents, User
 | 
				
			||||||
from authentik.sources.scim.models import SCIMSource
 | 
					from authentik.sources.scim.models import SCIMSource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication):
 | 
				
			|||||||
        _username, _, password = b64decode(key.encode()).decode().partition(":")
 | 
					        _username, _, password = b64decode(key.encode()).decode().partition(":")
 | 
				
			||||||
        token = self.check_token(password, source_slug)
 | 
					        token = self.check_token(password, source_slug)
 | 
				
			||||||
        if token:
 | 
					        if token:
 | 
				
			||||||
            CTX_AUTH_VIA.set("scim_basic")
 | 
					 | 
				
			||||||
            return (token.user, token)
 | 
					            return (token.user, token)
 | 
				
			||||||
        return None
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication):
 | 
				
			|||||||
        token = self.check_token(key, source_slug)
 | 
					        token = self.check_token(key, source_slug)
 | 
				
			||||||
        if not token:
 | 
					        if not token:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
        CTX_AUTH_VIA.set("scim_token")
 | 
					 | 
				
			||||||
        return (token.user, token)
 | 
					        return (token.user, token)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,13 @@
 | 
				
			|||||||
"""SCIM Utils"""
 | 
					"""SCIM Utils"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
 | 
					from urllib.parse import urlparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.core.paginator import Page, Paginator
 | 
					from django.core.paginator import Page, Paginator
 | 
				
			||||||
from django.db.models import Q, QuerySet
 | 
					from django.db.models import Q, QuerySet
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
 | 
					from django.urls import resolve
 | 
				
			||||||
from rest_framework.parsers import JSONParser
 | 
					from rest_framework.parsers import JSONParser
 | 
				
			||||||
from rest_framework.permissions import IsAuthenticated
 | 
					from rest_framework.permissions import IsAuthenticated
 | 
				
			||||||
from rest_framework.renderers import JSONRenderer
 | 
					from rest_framework.renderers import JSONRenderer
 | 
				
			||||||
@ -44,7 +46,7 @@ class SCIMView(APIView):
 | 
				
			|||||||
    logger: BoundLogger
 | 
					    logger: BoundLogger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    permission_classes = [IsAuthenticated]
 | 
					    permission_classes = [IsAuthenticated]
 | 
				
			||||||
    parser_classes = [SCIMParser, JSONParser]
 | 
					    parser_classes = [SCIMParser]
 | 
				
			||||||
    renderer_classes = [SCIMRenderer]
 | 
					    renderer_classes = [SCIMRenderer]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
 | 
					    def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
 | 
				
			||||||
@ -54,6 +56,28 @@ class SCIMView(APIView):
 | 
				
			|||||||
    def get_authenticators(self):
 | 
					    def get_authenticators(self):
 | 
				
			||||||
        return [SCIMTokenAuth(self)]
 | 
					        return [SCIMTokenAuth(self)]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def patch_resolve_value(self, raw_value: dict) -> User | Group | None:
 | 
				
			||||||
 | 
					        """Attempt to resolve a raw `value` attribute of a patch operation into
 | 
				
			||||||
 | 
					        a database model"""
 | 
				
			||||||
 | 
					        model = User
 | 
				
			||||||
 | 
					        query = {}
 | 
				
			||||||
 | 
					        if "$ref" in raw_value:
 | 
				
			||||||
 | 
					            url = urlparse(raw_value["$ref"])
 | 
				
			||||||
 | 
					            if match := resolve(url.path):
 | 
				
			||||||
 | 
					                if match.url_name == "v2-users":
 | 
				
			||||||
 | 
					                    model = User
 | 
				
			||||||
 | 
					                    query = {"pk": int(match.kwargs["user_id"])}
 | 
				
			||||||
 | 
					        elif "type" in raw_value:
 | 
				
			||||||
 | 
					            match raw_value["type"]:
 | 
				
			||||||
 | 
					                case "User":
 | 
				
			||||||
 | 
					                    model = User
 | 
				
			||||||
 | 
					                    query = {"pk": int(raw_value["value"])}
 | 
				
			||||||
 | 
					                case "Group":
 | 
				
			||||||
 | 
					                    model = Group
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        return model.objects.filter(**query).first()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def filter_parse(self, request: Request):
 | 
					    def filter_parse(self, request: Request):
 | 
				
			||||||
        """Parse the path of a Patch Operation"""
 | 
					        """Parse the path of a Patch Operation"""
 | 
				
			||||||
        path = request.query_params.get("filter")
 | 
					        path = request.query_params.get("filter")
 | 
				
			||||||
 | 
				
			|||||||
@ -1,58 +0,0 @@
 | 
				
			|||||||
from enum import Enum
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from pydanticscim.responses import SCIMError as BaseSCIMError
 | 
					 | 
				
			||||||
from rest_framework.exceptions import ValidationError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SCIMErrorTypes(Enum):
 | 
					 | 
				
			||||||
    invalid_filter = "invalidFilter"
 | 
					 | 
				
			||||||
    too_many = "tooMany"
 | 
					 | 
				
			||||||
    uniqueness = "uniqueness"
 | 
					 | 
				
			||||||
    mutability = "mutability"
 | 
					 | 
				
			||||||
    invalid_syntax = "invalidSyntax"
 | 
					 | 
				
			||||||
    invalid_path = "invalidPath"
 | 
					 | 
				
			||||||
    no_target = "noTarget"
 | 
					 | 
				
			||||||
    invalid_value = "invalidValue"
 | 
					 | 
				
			||||||
    invalid_vers = "invalidVers"
 | 
					 | 
				
			||||||
    sensitive = "sensitive"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SCIMError(BaseSCIMError):
 | 
					 | 
				
			||||||
    scimType: SCIMErrorTypes | None = None
 | 
					 | 
				
			||||||
    detail: str | None = None
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SCIMValidationError(ValidationError):
 | 
					 | 
				
			||||||
    status_code = 400
 | 
					 | 
				
			||||||
    default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, detail: SCIMError | None):
 | 
					 | 
				
			||||||
        if detail is None:
 | 
					 | 
				
			||||||
            detail = self.default_detail
 | 
					 | 
				
			||||||
        detail.status = self.status_code
 | 
					 | 
				
			||||||
        self.detail = detail.model_dump(mode="json", exclude_none=True)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SCIMConflictError(SCIMValidationError):
 | 
					 | 
				
			||||||
    status_code = 409
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, detail: str):
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            SCIMError(
 | 
					 | 
				
			||||||
                detail=detail,
 | 
					 | 
				
			||||||
                scimType=SCIMErrorTypes.uniqueness,
 | 
					 | 
				
			||||||
                status=self.status_code,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SCIMNotFoundError(SCIMValidationError):
 | 
					 | 
				
			||||||
    status_code = 404
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, detail: str):
 | 
					 | 
				
			||||||
        super().__init__(
 | 
					 | 
				
			||||||
            SCIMError(
 | 
					 | 
				
			||||||
                detail=detail,
 | 
					 | 
				
			||||||
                status=self.status_code,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
@ -4,25 +4,19 @@ from uuid import uuid4
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from django.db.models import Q
 | 
					from django.db.models import Q
 | 
				
			||||||
from django.db.transaction import atomic
 | 
					from django.db.transaction import atomic
 | 
				
			||||||
from django.http import QueryDict
 | 
					from django.http import Http404, QueryDict
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from pydantic import ValidationError as PydanticValidationError
 | 
					from pydantic import ValidationError as PydanticValidationError
 | 
				
			||||||
from pydanticscim.group import GroupMember
 | 
					from pydanticscim.group import GroupMember
 | 
				
			||||||
from rest_framework.exceptions import ValidationError
 | 
					from rest_framework.exceptions import ValidationError
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
from scim2_filter_parser.attr_paths import AttrPath
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.core.models import Group, User
 | 
					from authentik.core.models import Group, User
 | 
				
			||||||
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
 | 
					from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
 | 
				
			||||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
 | 
					from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
 | 
				
			||||||
from authentik.sources.scim.models import SCIMSourceGroup
 | 
					from authentik.sources.scim.models import SCIMSourceGroup
 | 
				
			||||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
 | 
					from authentik.sources.scim.views.v2.base import SCIMObjectView
 | 
				
			||||||
from authentik.sources.scim.views.v2.exceptions import (
 | 
					 | 
				
			||||||
    SCIMConflictError,
 | 
					 | 
				
			||||||
    SCIMNotFoundError,
 | 
					 | 
				
			||||||
    SCIMValidationError,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GroupsView(SCIMObjectView):
 | 
					class GroupsView(SCIMObjectView):
 | 
				
			||||||
@ -33,7 +27,7 @@ class GroupsView(SCIMObjectView):
 | 
				
			|||||||
    def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
 | 
					    def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
 | 
				
			||||||
        """Convert Group to SCIM data"""
 | 
					        """Convert Group to SCIM data"""
 | 
				
			||||||
        payload = SCIMGroupModel(
 | 
					        payload = SCIMGroupModel(
 | 
				
			||||||
            schemas=[SCIM_GROUP_SCHEMA],
 | 
					            schemas=[SCIM_USER_SCHEMA],
 | 
				
			||||||
            id=str(scim_group.group.pk),
 | 
					            id=str(scim_group.group.pk),
 | 
				
			||||||
            externalId=scim_group.id,
 | 
					            externalId=scim_group.id,
 | 
				
			||||||
            displayName=scim_group.group.name,
 | 
					            displayName=scim_group.group.name,
 | 
				
			||||||
@ -64,7 +58,7 @@ class GroupsView(SCIMObjectView):
 | 
				
			|||||||
        if group_id:
 | 
					        if group_id:
 | 
				
			||||||
            connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
 | 
					            connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
 | 
				
			||||||
            if not connection:
 | 
					            if not connection:
 | 
				
			||||||
                raise SCIMNotFoundError("Group not found.")
 | 
					                raise Http404
 | 
				
			||||||
            return Response(self.group_to_scim(connection))
 | 
					            return Response(self.group_to_scim(connection))
 | 
				
			||||||
        connections = (
 | 
					        connections = (
 | 
				
			||||||
            base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
 | 
					            base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
 | 
				
			||||||
@ -125,7 +119,7 @@ class GroupsView(SCIMObjectView):
 | 
				
			|||||||
        ).first()
 | 
					        ).first()
 | 
				
			||||||
        if connection:
 | 
					        if connection:
 | 
				
			||||||
            self.logger.debug("Found existing group")
 | 
					            self.logger.debug("Found existing group")
 | 
				
			||||||
            raise SCIMConflictError("Group with ID exists already.")
 | 
					            return Response(status=409)
 | 
				
			||||||
        connection = self.update_group(None, request.data)
 | 
					        connection = self.update_group(None, request.data)
 | 
				
			||||||
        return Response(self.group_to_scim(connection), status=201)
 | 
					        return Response(self.group_to_scim(connection), status=201)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -135,44 +129,10 @@ class GroupsView(SCIMObjectView):
 | 
				
			|||||||
            source=self.source, group__group_uuid=group_id
 | 
					            source=self.source, group__group_uuid=group_id
 | 
				
			||||||
        ).first()
 | 
					        ).first()
 | 
				
			||||||
        if not connection:
 | 
					        if not connection:
 | 
				
			||||||
            raise SCIMNotFoundError("Group not found.")
 | 
					            raise Http404
 | 
				
			||||||
        connection = self.update_group(connection, request.data)
 | 
					        connection = self.update_group(connection, request.data)
 | 
				
			||||||
        return Response(self.group_to_scim(connection), status=200)
 | 
					        return Response(self.group_to_scim(connection), status=200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @atomic
 | 
					 | 
				
			||||||
    def patch(self, request: Request, group_id: str, **kwargs) -> Response:
 | 
					 | 
				
			||||||
        """Patch group handler"""
 | 
					 | 
				
			||||||
        connection = SCIMSourceGroup.objects.filter(
 | 
					 | 
				
			||||||
            source=self.source, group__group_uuid=group_id
 | 
					 | 
				
			||||||
        ).first()
 | 
					 | 
				
			||||||
        if not connection:
 | 
					 | 
				
			||||||
            raise SCIMNotFoundError("Group not found.")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for _op in request.data.get("Operations", []):
 | 
					 | 
				
			||||||
            operation = PatchOperation.model_validate(_op)
 | 
					 | 
				
			||||||
            if operation.op.lower() not in ["add", "remove", "replace"]:
 | 
					 | 
				
			||||||
                raise SCIMValidationError()
 | 
					 | 
				
			||||||
            attr_path = AttrPath(f'{operation.path} eq ""', {})
 | 
					 | 
				
			||||||
            if attr_path.first_path == ("members", None, None):
 | 
					 | 
				
			||||||
                # FIXME: this can probably be de-duplicated
 | 
					 | 
				
			||||||
                if operation.op == PatchOp.add:
 | 
					 | 
				
			||||||
                    if not isinstance(operation.value, list):
 | 
					 | 
				
			||||||
                        operation.value = [operation.value]
 | 
					 | 
				
			||||||
                    query = Q()
 | 
					 | 
				
			||||||
                    for member in operation.value:
 | 
					 | 
				
			||||||
                        query |= Q(uuid=member["value"])
 | 
					 | 
				
			||||||
                    if query:
 | 
					 | 
				
			||||||
                        connection.group.users.add(*User.objects.filter(query))
 | 
					 | 
				
			||||||
                elif operation.op == PatchOp.remove:
 | 
					 | 
				
			||||||
                    if not isinstance(operation.value, list):
 | 
					 | 
				
			||||||
                        operation.value = [operation.value]
 | 
					 | 
				
			||||||
                    query = Q()
 | 
					 | 
				
			||||||
                    for member in operation.value:
 | 
					 | 
				
			||||||
                        query |= Q(uuid=member["value"])
 | 
					 | 
				
			||||||
                    if query:
 | 
					 | 
				
			||||||
                        connection.group.users.remove(*User.objects.filter(query))
 | 
					 | 
				
			||||||
        return Response(self.group_to_scim(connection), status=200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @atomic
 | 
					    @atomic
 | 
				
			||||||
    def delete(self, request: Request, group_id: str, **kwargs) -> Response:
 | 
					    def delete(self, request: Request, group_id: str, **kwargs) -> Response:
 | 
				
			||||||
        """Delete group handler"""
 | 
					        """Delete group handler"""
 | 
				
			||||||
@ -180,7 +140,7 @@ class GroupsView(SCIMObjectView):
 | 
				
			|||||||
            source=self.source, group__group_uuid=group_id
 | 
					            source=self.source, group__group_uuid=group_id
 | 
				
			||||||
        ).first()
 | 
					        ).first()
 | 
				
			||||||
        if not connection:
 | 
					        if not connection:
 | 
				
			||||||
            raise SCIMNotFoundError("Group not found.")
 | 
					            raise Http404
 | 
				
			||||||
        connection.group.delete()
 | 
					        connection.group.delete()
 | 
				
			||||||
        connection.delete()
 | 
					        connection.delete()
 | 
				
			||||||
        return Response(status=204)
 | 
					        return Response(status=204)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,11 @@
 | 
				
			|||||||
"""SCIM Meta views"""
 | 
					"""SCIM Meta views"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.http import Http404
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.sources.scim.views.v2.base import SCIMView
 | 
					from authentik.sources.scim.views.v2.base import SCIMView
 | 
				
			||||||
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ResourceTypesView(SCIMView):
 | 
					class ResourceTypesView(SCIMView):
 | 
				
			||||||
@ -138,7 +138,7 @@ class ResourceTypesView(SCIMView):
 | 
				
			|||||||
            resource = [x for x in resource_types if x.get("id") == resource_type]
 | 
					            resource = [x for x in resource_types if x.get("id") == resource_type]
 | 
				
			||||||
            if resource:
 | 
					            if resource:
 | 
				
			||||||
                return Response(resource[0])
 | 
					                return Response(resource[0])
 | 
				
			||||||
            raise SCIMNotFoundError("Resource not found.")
 | 
					            raise Http404
 | 
				
			||||||
        return Response(
 | 
					        return Response(
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
 | 
					                "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
 | 
				
			||||||
 | 
				
			|||||||
@ -3,12 +3,12 @@
 | 
				
			|||||||
from json import loads
 | 
					from json import loads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
 | 
					from django.http import Http404
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from rest_framework.request import Request
 | 
					from rest_framework.request import Request
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.sources.scim.views.v2.base import SCIMView
 | 
					from authentik.sources.scim.views.v2.base import SCIMView
 | 
				
			||||||
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
with open(
 | 
					with open(
 | 
				
			||||||
    settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
 | 
					    settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
 | 
				
			||||||
@ -44,7 +44,7 @@ class SchemaView(SCIMView):
 | 
				
			|||||||
            schema = [x for x in schemas if x.get("id") == schema_uri]
 | 
					            schema = [x for x in schemas if x.get("id") == schema_uri]
 | 
				
			||||||
            if schema:
 | 
					            if schema:
 | 
				
			||||||
                return Response(schema[0])
 | 
					                return Response(schema[0])
 | 
				
			||||||
            raise SCIMNotFoundError("Schema not found.")
 | 
					            raise Http404
 | 
				
			||||||
        return Response(
 | 
					        return Response(
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
 | 
					                "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
 | 
				
			||||||
 | 
				
			|||||||
@ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView):
 | 
				
			|||||||
            {
 | 
					            {
 | 
				
			||||||
                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
 | 
					                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
 | 
				
			||||||
                "authenticationSchemes": auth_schemas,
 | 
					                "authenticationSchemes": auth_schemas,
 | 
				
			||||||
                # We only support patch for groups currently, so don't broadly advertise it.
 | 
					 | 
				
			||||||
                # Implementations that require Group patch will use it regardless of this flag.
 | 
					 | 
				
			||||||
                "patch": {"supported": False},
 | 
					                "patch": {"supported": False},
 | 
				
			||||||
                "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
 | 
					                "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
 | 
				
			||||||
                "filter": {
 | 
					                "filter": {
 | 
				
			||||||
 | 
				
			|||||||
@ -4,7 +4,7 @@ from uuid import uuid4
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from django.db.models import Q
 | 
					from django.db.models import Q
 | 
				
			||||||
from django.db.transaction import atomic
 | 
					from django.db.transaction import atomic
 | 
				
			||||||
from django.http import QueryDict
 | 
					from django.http import Http404, QueryDict
 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from pydanticscim.user import Email, EmailKind, Name
 | 
					from pydanticscim.user import Email, EmailKind, Name
 | 
				
			||||||
from rest_framework.exceptions import ValidationError
 | 
					from rest_framework.exceptions import ValidationError
 | 
				
			||||||
@ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
 | 
				
			|||||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
 | 
					from authentik.providers.scim.clients.schema import User as SCIMUserModel
 | 
				
			||||||
from authentik.sources.scim.models import SCIMSourceUser
 | 
					from authentik.sources.scim.models import SCIMSourceUser
 | 
				
			||||||
from authentik.sources.scim.views.v2.base import SCIMObjectView
 | 
					from authentik.sources.scim.views.v2.base import SCIMObjectView
 | 
				
			||||||
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class UsersView(SCIMObjectView):
 | 
					class UsersView(SCIMObjectView):
 | 
				
			||||||
@ -70,7 +69,7 @@ class UsersView(SCIMObjectView):
 | 
				
			|||||||
                .first()
 | 
					                .first()
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            if not connection:
 | 
					            if not connection:
 | 
				
			||||||
                raise SCIMNotFoundError("User not found.")
 | 
					                raise Http404
 | 
				
			||||||
            return Response(self.user_to_scim(connection))
 | 
					            return Response(self.user_to_scim(connection))
 | 
				
			||||||
        connections = (
 | 
					        connections = (
 | 
				
			||||||
            SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
 | 
					            SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
 | 
				
			||||||
@ -123,7 +122,7 @@ class UsersView(SCIMObjectView):
 | 
				
			|||||||
        ).first()
 | 
					        ).first()
 | 
				
			||||||
        if connection:
 | 
					        if connection:
 | 
				
			||||||
            self.logger.debug("Found existing user")
 | 
					            self.logger.debug("Found existing user")
 | 
				
			||||||
            raise SCIMConflictError("Group with ID exists already.")
 | 
					            return Response(status=409)
 | 
				
			||||||
        connection = self.update_user(None, request.data)
 | 
					        connection = self.update_user(None, request.data)
 | 
				
			||||||
        return Response(self.user_to_scim(connection), status=201)
 | 
					        return Response(self.user_to_scim(connection), status=201)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -131,7 +130,7 @@ class UsersView(SCIMObjectView):
 | 
				
			|||||||
        """Update user handler"""
 | 
					        """Update user handler"""
 | 
				
			||||||
        connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
 | 
					        connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
 | 
				
			||||||
        if not connection:
 | 
					        if not connection:
 | 
				
			||||||
            raise SCIMNotFoundError("User not found.")
 | 
					            raise Http404
 | 
				
			||||||
        self.update_user(connection, request.data)
 | 
					        self.update_user(connection, request.data)
 | 
				
			||||||
        return Response(self.user_to_scim(connection), status=200)
 | 
					        return Response(self.user_to_scim(connection), status=200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -140,7 +139,7 @@ class UsersView(SCIMObjectView):
 | 
				
			|||||||
        """Delete user handler"""
 | 
					        """Delete user handler"""
 | 
				
			||||||
        connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
 | 
					        connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
 | 
				
			||||||
        if not connection:
 | 
					        if not connection:
 | 
				
			||||||
            raise SCIMNotFoundError("User not found.")
 | 
					            raise Http404
 | 
				
			||||||
        connection.user.delete()
 | 
					        connection.user.delete()
 | 
				
			||||||
        connection.delete()
 | 
					        connection.delete()
 | 
				
			||||||
        return Response(status=204)
 | 
					        return Response(status=204)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
"""Validation stage challenge checking"""
 | 
					"""Validation stage challenge checking"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from json import loads
 | 
					from json import loads
 | 
				
			||||||
from typing import TYPE_CHECKING
 | 
					 | 
				
			||||||
from urllib.parse import urlencode
 | 
					from urllib.parse import urlencode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.http import HttpRequest
 | 
					from django.http import HttpRequest
 | 
				
			||||||
@ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice
 | 
				
			|||||||
from authentik.stages.authenticator_sms.models import SMSDevice
 | 
					from authentik.stages.authenticator_sms.models import SMSDevice
 | 
				
			||||||
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
 | 
					from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
 | 
					from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
 | 
					from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
					from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					 | 
				
			||||||
    from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DeviceChallenge(PassiveSerializer):
 | 
					class DeviceChallenge(PassiveSerializer):
 | 
				
			||||||
@ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_challenge_for_device(
 | 
					def get_challenge_for_device(
 | 
				
			||||||
    stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
 | 
					    request: HttpRequest, stage: AuthenticatorValidateStage, device: Device
 | 
				
			||||||
) -> dict:
 | 
					) -> dict:
 | 
				
			||||||
    """Generate challenge for a single device"""
 | 
					    """Generate challenge for a single device"""
 | 
				
			||||||
    if isinstance(device, WebAuthnDevice):
 | 
					    if isinstance(device, WebAuthnDevice):
 | 
				
			||||||
        return get_webauthn_challenge(stage_view, stage, device)
 | 
					        return get_webauthn_challenge(request, stage, device)
 | 
				
			||||||
    if isinstance(device, EmailDevice):
 | 
					    if isinstance(device, EmailDevice):
 | 
				
			||||||
        return {"email": mask_email(device.email)}
 | 
					        return {"email": mask_email(device.email)}
 | 
				
			||||||
    # Code-based challenges have no hints
 | 
					    # Code-based challenges have no hints
 | 
				
			||||||
@ -67,30 +64,26 @@ def get_challenge_for_device(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_webauthn_challenge_without_user(
 | 
					def get_webauthn_challenge_without_user(
 | 
				
			||||||
    stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
 | 
					    request: HttpRequest, stage: AuthenticatorValidateStage
 | 
				
			||||||
) -> dict:
 | 
					) -> dict:
 | 
				
			||||||
    """Same as `get_webauthn_challenge`, but allows any client device. We can then later check
 | 
					    """Same as `get_webauthn_challenge`, but allows any client device. We can then later check
 | 
				
			||||||
    who the device belongs to."""
 | 
					    who the device belongs to."""
 | 
				
			||||||
    stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
 | 
					    request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
    authentication_options = generate_authentication_options(
 | 
					    authentication_options = generate_authentication_options(
 | 
				
			||||||
        rp_id=get_rp_id(stage_view.request),
 | 
					        rp_id=get_rp_id(request),
 | 
				
			||||||
        allow_credentials=[],
 | 
					        allow_credentials=[],
 | 
				
			||||||
        user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
 | 
					        user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
 | 
					    request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
 | 
				
			||||||
        authentication_options.challenge
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return loads(options_to_json(authentication_options))
 | 
					    return loads(options_to_json(authentication_options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_webauthn_challenge(
 | 
					def get_webauthn_challenge(
 | 
				
			||||||
    stage_view: "AuthenticatorValidateStageView",
 | 
					    request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None
 | 
				
			||||||
    stage: AuthenticatorValidateStage,
 | 
					 | 
				
			||||||
    device: WebAuthnDevice | None = None,
 | 
					 | 
				
			||||||
) -> dict:
 | 
					) -> dict:
 | 
				
			||||||
    """Send the client a challenge that we'll check later"""
 | 
					    """Send the client a challenge that we'll check later"""
 | 
				
			||||||
    stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
 | 
					    request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    allowed_credentials = []
 | 
					    allowed_credentials = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -101,14 +94,12 @@ def get_webauthn_challenge(
 | 
				
			|||||||
            allowed_credentials.append(user_device.descriptor)
 | 
					            allowed_credentials.append(user_device.descriptor)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    authentication_options = generate_authentication_options(
 | 
					    authentication_options = generate_authentication_options(
 | 
				
			||||||
        rp_id=get_rp_id(stage_view.request),
 | 
					        rp_id=get_rp_id(request),
 | 
				
			||||||
        allow_credentials=allowed_credentials,
 | 
					        allow_credentials=allowed_credentials,
 | 
				
			||||||
        user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
 | 
					        user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
 | 
					    request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
 | 
				
			||||||
        authentication_options.challenge
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return loads(options_to_json(authentication_options))
 | 
					    return loads(options_to_json(authentication_options))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
 | 
				
			|||||||
def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
 | 
					def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
 | 
				
			||||||
    """Validate WebAuthn Challenge"""
 | 
					    """Validate WebAuthn Challenge"""
 | 
				
			||||||
    request = stage_view.request
 | 
					    request = stage_view.request
 | 
				
			||||||
    challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
 | 
					    challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
 | 
				
			||||||
    stage: AuthenticatorValidateStage = stage_view.executor.current_stage
 | 
					    stage: AuthenticatorValidateStage = stage_view.executor.current_stage
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        credential = parse_authentication_credential_json(data)
 | 
					        credential = parse_authentication_credential_json(data)
 | 
				
			||||||
 | 
				
			|||||||
@ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
                data={
 | 
					                data={
 | 
				
			||||||
                    "device_class": device_class,
 | 
					                    "device_class": device_class,
 | 
				
			||||||
                    "device_uid": device.pk,
 | 
					                    "device_uid": device.pk,
 | 
				
			||||||
                    "challenge": get_challenge_for_device(self, stage, device),
 | 
					                    "challenge": get_challenge_for_device(self.request, stage, device),
 | 
				
			||||||
                    "last_used": device.last_used,
 | 
					                    "last_used": device.last_used,
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
@ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
 | 
				
			|||||||
                "device_class": DeviceClasses.WEBAUTHN,
 | 
					                "device_class": DeviceClasses.WEBAUTHN,
 | 
				
			||||||
                "device_uid": -1,
 | 
					                "device_uid": -1,
 | 
				
			||||||
                "challenge": get_webauthn_challenge_without_user(
 | 
					                "challenge": get_webauthn_challenge_without_user(
 | 
				
			||||||
                    self,
 | 
					                    self.request,
 | 
				
			||||||
                    self.executor.current_stage,
 | 
					                    self.executor.current_stage,
 | 
				
			||||||
                ),
 | 
					                ),
 | 
				
			||||||
                "last_used": None,
 | 
					                "last_used": None,
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import (
 | 
				
			|||||||
    WebAuthnDevice,
 | 
					    WebAuthnDevice,
 | 
				
			||||||
    WebAuthnDeviceType,
 | 
					    WebAuthnDeviceType,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
 | 
					from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
 | 
					from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
 | 
				
			||||||
from authentik.stages.identification.models import IdentificationStage, UserFields
 | 
					from authentik.stages.identification.models import IdentificationStage, UserFields
 | 
				
			||||||
from authentik.stages.user_login.models import UserLoginStage
 | 
					from authentik.stages.user_login.models import UserLoginStage
 | 
				
			||||||
@ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
            device_classes=[DeviceClasses.WEBAUTHN],
 | 
					            device_classes=[DeviceClasses.WEBAUTHN],
 | 
				
			||||||
            webauthn_user_verification=UserVerification.PREFERRED,
 | 
					            webauthn_user_verification=UserVerification.PREFERRED,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        plan = FlowPlan("")
 | 
					        challenge = get_challenge_for_device(request, stage, webauthn_device)
 | 
				
			||||||
        stage_view = AuthenticatorValidateStageView(
 | 
					 | 
				
			||||||
            FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
 | 
					 | 
				
			||||||
        del challenge["challenge"]
 | 
					        del challenge["challenge"]
 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            challenge,
 | 
					            challenge,
 | 
				
			||||||
@ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        with self.assertRaises(ValidationError):
 | 
					        with self.assertRaises(ValidationError):
 | 
				
			||||||
            validate_challenge_webauthn(
 | 
					            validate_challenge_webauthn(
 | 
				
			||||||
                {},
 | 
					                {}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user
 | 
				
			||||||
                StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request),
 | 
					 | 
				
			||||||
                self.user,
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_device_challenge_webauthn_restricted(self):
 | 
					    def test_device_challenge_webauthn_restricted(self):
 | 
				
			||||||
@ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
            sign_count=0,
 | 
					            sign_count=0,
 | 
				
			||||||
            rp_id=generate_id(),
 | 
					            rp_id=generate_id(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        plan = FlowPlan("")
 | 
					        challenge = get_challenge_for_device(request, stage, webauthn_device)
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
					        webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
 | 
				
			||||||
            "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        stage_view = AuthenticatorValidateStageView(
 | 
					 | 
				
			||||||
            FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					        self.assertEqual(
 | 
				
			||||||
            challenge["allowCredentials"],
 | 
					            challenge,
 | 
				
			||||||
            [
 | 
					            {
 | 
				
			||||||
                {
 | 
					                "allowCredentials": [
 | 
				
			||||||
                    "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
 | 
					                    {
 | 
				
			||||||
                    "type": "public-key",
 | 
					                        "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
 | 
				
			||||||
                }
 | 
					                        "type": "public-key",
 | 
				
			||||||
            ],
 | 
					                    }
 | 
				
			||||||
        )
 | 
					                ],
 | 
				
			||||||
        self.assertIsNotNone(challenge["challenge"])
 | 
					                "challenge": bytes_to_base64url(webauthn_challenge),
 | 
				
			||||||
        self.assertEqual(
 | 
					                "rpId": "testserver",
 | 
				
			||||||
            challenge["rpId"],
 | 
					                "timeout": 60000,
 | 
				
			||||||
            "testserver",
 | 
					                "userVerification": "preferred",
 | 
				
			||||||
        )
 | 
					            },
 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            challenge["timeout"],
 | 
					 | 
				
			||||||
            60000,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            challenge["userVerification"],
 | 
					 | 
				
			||||||
            "preferred",
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_get_challenge_userless(self):
 | 
					    def test_get_challenge_userless(self):
 | 
				
			||||||
@ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
            sign_count=0,
 | 
					            sign_count=0,
 | 
				
			||||||
            rp_id=generate_id(),
 | 
					            rp_id=generate_id(),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        plan = FlowPlan("")
 | 
					        challenge = get_webauthn_challenge_without_user(request, stage)
 | 
				
			||||||
        stage_view = AuthenticatorValidateStageView(
 | 
					        webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
 | 
				
			||||||
            FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            challenge,
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "allowCredentials": [],
 | 
				
			||||||
 | 
					                "challenge": bytes_to_base64url(webauthn_challenge),
 | 
				
			||||||
 | 
					                "rpId": "testserver",
 | 
				
			||||||
 | 
					                "timeout": 60000,
 | 
				
			||||||
 | 
					                "userVerification": "preferred",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        challenge = get_webauthn_challenge_without_user(stage_view, stage)
 | 
					 | 
				
			||||||
        self.assertEqual(challenge["allowCredentials"], [])
 | 
					 | 
				
			||||||
        self.assertIsNotNone(challenge["challenge"])
 | 
					 | 
				
			||||||
        self.assertEqual(challenge["rpId"], "testserver")
 | 
					 | 
				
			||||||
        self.assertEqual(challenge["timeout"], 60000)
 | 
					 | 
				
			||||||
        self.assertEqual(challenge["userVerification"], "preferred")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_validate_challenge_unrestricted(self):
 | 
					    def test_validate_challenge_unrestricted(self):
 | 
				
			||||||
        """Test webauthn authentication (unrestricted webauthn device)"""
 | 
					        """Test webauthn authentication (unrestricted webauthn device)"""
 | 
				
			||||||
@ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
                "last_used": None,
 | 
					                "last_used": None,
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
				
			||||||
            "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
 | 
					            "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
@ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
                "last_used": None,
 | 
					                "last_used": None,
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
				
			||||||
            "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
 | 
					            "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
@ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
                "last_used": None,
 | 
					                "last_used": None,
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
				
			||||||
            "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
 | 
					            "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
@ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
 | 
				
			|||||||
            not_configured_action=NotConfiguredAction.CONFIGURE,
 | 
					            not_configured_action=NotConfiguredAction.CONFIGURE,
 | 
				
			||||||
            device_classes=[DeviceClasses.WEBAUTHN],
 | 
					            device_classes=[DeviceClasses.WEBAUTHN],
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        plan = FlowPlan(flow.pk.hex)
 | 
					        stage_view = AuthenticatorValidateStageView(
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
					            FlowExecutorView(flow=flow, current_stage=stage), request=request
 | 
				
			||||||
            "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        request = get_request("/")
 | 
					        request = get_request("/")
 | 
				
			||||||
 | 
					        request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
 | 
				
			||||||
 | 
					            "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        request.session.save()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        stage_view = AuthenticatorValidateStageView(
 | 
					        stage_view = AuthenticatorValidateStageView(
 | 
				
			||||||
            FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request
 | 
					            FlowExecutorView(flow=flow, current_stage=stage), request=request
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        request.META["SERVER_NAME"] = "localhost"
 | 
					        request.META["SERVER_NAME"] = "localhost"
 | 
				
			||||||
        request.META["SERVER_PORT"] = "9000"
 | 
					        request.META["SERVER_PORT"] = "9000"
 | 
				
			||||||
 | 
				
			|||||||
@ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer):
 | 
				
			|||||||
            "resident_key_requirement",
 | 
					            "resident_key_requirement",
 | 
				
			||||||
            "device_type_restrictions",
 | 
					            "device_type_restrictions",
 | 
				
			||||||
            "device_type_restrictions_obj",
 | 
					            "device_type_restrictions_obj",
 | 
				
			||||||
            "max_attempts",
 | 
					 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -1,21 +0,0 @@
 | 
				
			|||||||
# Generated by Django 5.1.11 on 2025-06-13 22:41
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.db import migrations, models
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Migration(migrations.Migration):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    dependencies = [
 | 
					 | 
				
			||||||
        (
 | 
					 | 
				
			||||||
            "authentik_stages_authenticator_webauthn",
 | 
					 | 
				
			||||||
            "0012_webauthndevice_created_webauthndevice_last_updated_and_more",
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    operations = [
 | 
					 | 
				
			||||||
        migrations.AddField(
 | 
					 | 
				
			||||||
            model_name="authenticatorwebauthnstage",
 | 
					 | 
				
			||||||
            name="max_attempts",
 | 
					 | 
				
			||||||
            field=models.PositiveIntegerField(default=0),
 | 
					 | 
				
			||||||
        ),
 | 
					 | 
				
			||||||
    ]
 | 
					 | 
				
			||||||
@ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True)
 | 
					    device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    max_attempts = models.PositiveIntegerField(default=0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def serializer(self) -> type[BaseSerializer]:
 | 
					    def serializer(self) -> type[BaseSerializer]:
 | 
				
			||||||
        from authentik.stages.authenticator_webauthn.api.stages import (
 | 
					        from authentik.stages.authenticator_webauthn.api.stages import (
 | 
				
			||||||
 | 
				
			|||||||
@ -5,13 +5,12 @@ from uuid import UUID
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from django.http import HttpRequest, HttpResponse
 | 
					from django.http import HttpRequest, HttpResponse
 | 
				
			||||||
from django.http.request import QueryDict
 | 
					from django.http.request import QueryDict
 | 
				
			||||||
from django.utils.translation import gettext as __
 | 
					 | 
				
			||||||
from django.utils.translation import gettext_lazy as _
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from rest_framework.fields import CharField
 | 
					from rest_framework.fields import CharField
 | 
				
			||||||
from rest_framework.serializers import ValidationError
 | 
					from rest_framework.serializers import ValidationError
 | 
				
			||||||
from webauthn import options_to_json
 | 
					from webauthn import options_to_json
 | 
				
			||||||
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
 | 
					from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
 | 
				
			||||||
from webauthn.helpers.exceptions import WebAuthnException
 | 
					from webauthn.helpers.exceptions import InvalidRegistrationResponse
 | 
				
			||||||
from webauthn.helpers.structs import (
 | 
					from webauthn.helpers.structs import (
 | 
				
			||||||
    AttestationConveyancePreference,
 | 
					    AttestationConveyancePreference,
 | 
				
			||||||
    AuthenticatorAttachment,
 | 
					    AuthenticatorAttachment,
 | 
				
			||||||
@ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
					from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge"
 | 
					SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge"
 | 
				
			||||||
PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge):
 | 
					class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge):
 | 
				
			||||||
@ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def validate_response(self, response: dict) -> dict:
 | 
					    def validate_response(self, response: dict) -> dict:
 | 
				
			||||||
        """Validate webauthn challenge response"""
 | 
					        """Validate webauthn challenge response"""
 | 
				
			||||||
        challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]
 | 
					        challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            registration: VerifiedRegistration = verify_registration_response(
 | 
					            registration: VerifiedRegistration = verify_registration_response(
 | 
				
			||||||
@ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
 | 
				
			|||||||
                expected_rp_id=get_rp_id(self.request),
 | 
					                expected_rp_id=get_rp_id(self.request),
 | 
				
			||||||
                expected_origin=get_origin(self.request),
 | 
					                expected_origin=get_origin(self.request),
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        except WebAuthnException as exc:
 | 
					        except InvalidRegistrationResponse as exc:
 | 
				
			||||||
            self.stage.logger.warning("registration failed", exc=exc)
 | 
					            self.stage.logger.warning("registration failed", exc=exc)
 | 
				
			||||||
            raise ValidationError(f"Registration failed. Error: {exc}") from None
 | 
					            raise ValidationError(f"Registration failed. Error: {exc}") from None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
 | 
				
			|||||||
    response_class = AuthenticatorWebAuthnChallengeResponse
 | 
					    response_class = AuthenticatorWebAuthnChallengeResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_challenge(self, *args, **kwargs) -> Challenge:
 | 
					    def get_challenge(self, *args, **kwargs) -> Challenge:
 | 
				
			||||||
 | 
					        # clear session variables prior to starting a new registration
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
        stage: AuthenticatorWebAuthnStage = self.executor.current_stage
 | 
					        stage: AuthenticatorWebAuthnStage = self.executor.current_stage
 | 
				
			||||||
        self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
 | 
					 | 
				
			||||||
        # clear flow variables prior to starting a new registration
 | 
					 | 
				
			||||||
        self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
 | 
					 | 
				
			||||||
        user = self.get_pending_user()
 | 
					        user = self.get_pending_user()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # library accepts none so we store null in the database, but if there is a value
 | 
					        # library accepts none so we store null in the database, but if there is a value
 | 
				
			||||||
@ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
 | 
				
			|||||||
            attestation=AttestationConveyancePreference.DIRECT,
 | 
					            attestation=AttestationConveyancePreference.DIRECT,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge
 | 
					        self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge
 | 
				
			||||||
 | 
					        self.request.session.save()
 | 
				
			||||||
        return AuthenticatorWebAuthnChallenge(
 | 
					        return AuthenticatorWebAuthnChallenge(
 | 
				
			||||||
            data={
 | 
					            data={
 | 
				
			||||||
                "registration": loads(options_to_json(registration_options)),
 | 
					                "registration": loads(options_to_json(registration_options)),
 | 
				
			||||||
@ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
 | 
				
			|||||||
        response.user = self.get_pending_user()
 | 
					        response.user = self.get_pending_user()
 | 
				
			||||||
        return response
 | 
					        return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def challenge_invalid(self, response):
 | 
					 | 
				
			||||||
        stage: AuthenticatorWebAuthnStage = self.executor.current_stage
 | 
					 | 
				
			||||||
        self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
 | 
					 | 
				
			||||||
        self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1
 | 
					 | 
				
			||||||
        if (
 | 
					 | 
				
			||||||
            stage.max_attempts > 0
 | 
					 | 
				
			||||||
            and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts
 | 
					 | 
				
			||||||
        ):
 | 
					 | 
				
			||||||
            return self.executor.stage_invalid(
 | 
					 | 
				
			||||||
                __(
 | 
					 | 
				
			||||||
                    "Exceeded maximum attempts. "
 | 
					 | 
				
			||||||
                    "Contact your {brand} administrator for help.".format(
 | 
					 | 
				
			||||||
                        brand=self.request.brand.branding_title
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        return super().challenge_invalid(response)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
					    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
 | 
				
			||||||
        # Webauthn Challenge has already been validated
 | 
					        # Webauthn Challenge has already been validated
 | 
				
			||||||
        webauthn_credential: VerifiedRegistration = response.validated_data["response"]
 | 
					        webauthn_credential: VerifiedRegistration = response.validated_data["response"]
 | 
				
			||||||
@ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return self.executor.stage_invalid("Device with Credential ID already exists.")
 | 
					            return self.executor.stage_invalid("Device with Credential ID already exists.")
 | 
				
			||||||
        return self.executor.stage_ok()
 | 
					        return self.executor.stage_ok()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
 | 
				
			||||||
 | 
				
			|||||||
@ -18,7 +18,7 @@ from authentik.stages.authenticator_webauthn.models import (
 | 
				
			|||||||
    WebAuthnDevice,
 | 
					    WebAuthnDevice,
 | 
				
			||||||
    WebAuthnDeviceType,
 | 
					    WebAuthnDeviceType,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
 | 
					from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
 | 
				
			||||||
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
 | 
					from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -57,9 +57,6 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
        response = self.client.get(
 | 
					        response = self.client.get(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					 | 
				
			||||||
        plan: FlowPlan = self.client.session[SESSION_KEY_PLAN]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
        session = self.client.session
 | 
					        session = self.client.session
 | 
				
			||||||
        self.assertStageResponse(
 | 
					        self.assertStageResponse(
 | 
				
			||||||
@ -73,7 +70,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
                    "name": self.user.username,
 | 
					                    "name": self.user.username,
 | 
				
			||||||
                    "displayName": self.user.name,
 | 
					                    "displayName": self.user.name,
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                "challenge": bytes_to_base64url(plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]),
 | 
					                "challenge": bytes_to_base64url(session[SESSION_KEY_WEBAUTHN_CHALLENGE]),
 | 
				
			||||||
                "pubKeyCredParams": [
 | 
					                "pubKeyCredParams": [
 | 
				
			||||||
                    {"type": "public-key", "alg": -7},
 | 
					                    {"type": "public-key", "alg": -7},
 | 
				
			||||||
                    {"type": "public-key", "alg": -8},
 | 
					                    {"type": "public-key", "alg": -8},
 | 
				
			||||||
@ -100,11 +97,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
        """Test registration"""
 | 
					        """Test registration"""
 | 
				
			||||||
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
					        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
					        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
 | 
					 | 
				
			||||||
            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					        session = self.client.session
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
 | 
				
			||||||
 | 
					            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
@ -149,11 +146,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
					        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
					        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
 | 
					 | 
				
			||||||
            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					        session = self.client.session
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
 | 
				
			||||||
 | 
					            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
@ -212,11 +209,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
					        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
					        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
 | 
					 | 
				
			||||||
            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					        session = self.client.session
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
 | 
				
			||||||
 | 
					            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
@ -262,11 +259,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
					        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
					        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
 | 
					 | 
				
			||||||
            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					        session = self.client.session
 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					        session[SESSION_KEY_PLAN] = plan
 | 
				
			||||||
 | 
					        session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
 | 
				
			||||||
 | 
					            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        session.save()
 | 
					        session.save()
 | 
				
			||||||
        response = self.client.post(
 | 
					        response = self.client.post(
 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
				
			||||||
@ -301,109 +298,3 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
 | 
				
			|||||||
        self.assertEqual(response.status_code, 200)
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
					        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
				
			||||||
        self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists())
 | 
					        self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists())
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_register_max_retries(self):
 | 
					 | 
				
			||||||
        """Test registration (exceeding max retries)"""
 | 
					 | 
				
			||||||
        self.stage.max_attempts = 2
 | 
					 | 
				
			||||||
        self.stage.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
 | 
					 | 
				
			||||||
        plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
 | 
					 | 
				
			||||||
            b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        session = self.client.session
 | 
					 | 
				
			||||||
        session[SESSION_KEY_PLAN] = plan
 | 
					 | 
				
			||||||
        session.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # first failed request
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					 | 
				
			||||||
            data={
 | 
					 | 
				
			||||||
                "component": "ak-stage-authenticator-webauthn",
 | 
					 | 
				
			||||||
                "response": {
 | 
					 | 
				
			||||||
                    "id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
 | 
					 | 
				
			||||||
                    "rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
 | 
					 | 
				
			||||||
                    "type": "public-key",
 | 
					 | 
				
			||||||
                    "registrationClientExtensions": "{}",
 | 
					 | 
				
			||||||
                    "response": {
 | 
					 | 
				
			||||||
                        "clientDataJSON": (
 | 
					 | 
				
			||||||
                            "eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
 | 
					 | 
				
			||||||
                            "lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
 | 
					 | 
				
			||||||
                            "pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
 | 
					 | 
				
			||||||
                            "mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
 | 
					 | 
				
			||||||
                            "Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
 | 
					 | 
				
			||||||
                        ),
 | 
					 | 
				
			||||||
                        "attestationObject": (
 | 
					 | 
				
			||||||
                            "o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
 | 
					 | 
				
			||||||
                            "OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
 | 
					 | 
				
			||||||
                            "cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
 | 
					 | 
				
			||||||
                            "QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
 | 
					 | 
				
			||||||
                            "2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
 | 
					 | 
				
			||||||
                        ),
 | 
					 | 
				
			||||||
                    },
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            SERVER_NAME="localhost",
 | 
					 | 
				
			||||||
            SERVER_PORT="9000",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertStageResponse(
 | 
					 | 
				
			||||||
            response,
 | 
					 | 
				
			||||||
            flow=self.flow,
 | 
					 | 
				
			||||||
            component="ak-stage-authenticator-webauthn",
 | 
					 | 
				
			||||||
            response_errors={
 | 
					 | 
				
			||||||
                "response": [
 | 
					 | 
				
			||||||
                    {
 | 
					 | 
				
			||||||
                        "string": (
 | 
					 | 
				
			||||||
                            "Registration failed. Error: Unable to decode "
 | 
					 | 
				
			||||||
                            "client_data_json bytes as JSON"
 | 
					 | 
				
			||||||
                        ),
 | 
					 | 
				
			||||||
                        "code": "invalid",
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                ]
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Second failed request
 | 
					 | 
				
			||||||
        response = self.client.post(
 | 
					 | 
				
			||||||
            reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
 | 
					 | 
				
			||||||
            data={
 | 
					 | 
				
			||||||
                "component": "ak-stage-authenticator-webauthn",
 | 
					 | 
				
			||||||
                "response": {
 | 
					 | 
				
			||||||
                    "id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
 | 
					 | 
				
			||||||
                    "rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
 | 
					 | 
				
			||||||
                    "type": "public-key",
 | 
					 | 
				
			||||||
                    "registrationClientExtensions": "{}",
 | 
					 | 
				
			||||||
                    "response": {
 | 
					 | 
				
			||||||
                        "clientDataJSON": (
 | 
					 | 
				
			||||||
                            "eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
 | 
					 | 
				
			||||||
                            "lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
 | 
					 | 
				
			||||||
                            "pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
 | 
					 | 
				
			||||||
                            "mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
 | 
					 | 
				
			||||||
                            "Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
 | 
					 | 
				
			||||||
                        ),
 | 
					 | 
				
			||||||
                        "attestationObject": (
 | 
					 | 
				
			||||||
                            "o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
 | 
					 | 
				
			||||||
                            "OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
 | 
					 | 
				
			||||||
                            "cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
 | 
					 | 
				
			||||||
                            "QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
 | 
					 | 
				
			||||||
                            "2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
 | 
					 | 
				
			||||||
                        ),
 | 
					 | 
				
			||||||
                    },
 | 
					 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
            },
 | 
					 | 
				
			||||||
            SERVER_NAME="localhost",
 | 
					 | 
				
			||||||
            SERVER_PORT="9000",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertStageResponse(
 | 
					 | 
				
			||||||
            response,
 | 
					 | 
				
			||||||
            flow=self.flow,
 | 
					 | 
				
			||||||
            component="ak-stage-access-denied",
 | 
					 | 
				
			||||||
            error_message=(
 | 
					 | 
				
			||||||
                "Exceeded maximum attempts. Contact your authentik administrator for help."
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -101,9 +101,9 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
            SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING
 | 
					            SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        if configured_binding_net != NetworkBinding.NO_BINDING:
 | 
					        if configured_binding_net != NetworkBinding.NO_BINDING:
 | 
				
			||||||
            BoundSessionMiddleware.recheck_session_net(configured_binding_net, last_ip, new_ip)
 | 
					            self.recheck_session_net(configured_binding_net, last_ip, new_ip)
 | 
				
			||||||
        if configured_binding_geo != GeoIPBinding.NO_BINDING:
 | 
					        if configured_binding_geo != GeoIPBinding.NO_BINDING:
 | 
				
			||||||
            BoundSessionMiddleware.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
 | 
					            self.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
 | 
				
			||||||
        # If we got to this point without any error being raised, we need to
 | 
					        # If we got to this point without any error being raised, we need to
 | 
				
			||||||
        # update the last saved IP to the current one
 | 
					        # update the last saved IP to the current one
 | 
				
			||||||
        if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
 | 
					        if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
 | 
				
			||||||
@ -111,8 +111,7 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
            # (== basically requires the user to be logged in)
 | 
					            # (== basically requires the user to be logged in)
 | 
				
			||||||
            request.session[request.session.model.Keys.LAST_IP] = new_ip
 | 
					            request.session[request.session.model.Keys.LAST_IP] = new_ip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    def recheck_session_net(self, binding: NetworkBinding, last_ip: str, new_ip: str):
 | 
				
			||||||
    def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
 | 
					 | 
				
			||||||
        """Check network/ASN binding"""
 | 
					        """Check network/ASN binding"""
 | 
				
			||||||
        last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip)
 | 
					        last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip)
 | 
				
			||||||
        new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip)
 | 
					        new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip)
 | 
				
			||||||
@ -159,8 +158,7 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
                    new_ip,
 | 
					                    new_ip,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    def recheck_session_geo(self, binding: GeoIPBinding, last_ip: str, new_ip: str):
 | 
				
			||||||
    def recheck_session_geo(binding: GeoIPBinding, last_ip: str, new_ip: str):
 | 
					 | 
				
			||||||
        """Check GeoIP binding"""
 | 
					        """Check GeoIP binding"""
 | 
				
			||||||
        last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip)
 | 
					        last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip)
 | 
				
			||||||
        new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip)
 | 
					        new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip)
 | 
				
			||||||
@ -181,8 +179,8 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
            if last_geo.continent != new_geo.continent:
 | 
					            if last_geo.continent != new_geo.continent:
 | 
				
			||||||
                raise SessionBindingBroken(
 | 
					                raise SessionBindingBroken(
 | 
				
			||||||
                    "geoip.continent",
 | 
					                    "geoip.continent",
 | 
				
			||||||
                    last_geo.continent.to_dict(),
 | 
					                    last_geo.continent,
 | 
				
			||||||
                    new_geo.continent.to_dict(),
 | 
					                    new_geo.continent,
 | 
				
			||||||
                    last_ip,
 | 
					                    last_ip,
 | 
				
			||||||
                    new_ip,
 | 
					                    new_ip,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
@ -194,8 +192,8 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
            if last_geo.country != new_geo.country:
 | 
					            if last_geo.country != new_geo.country:
 | 
				
			||||||
                raise SessionBindingBroken(
 | 
					                raise SessionBindingBroken(
 | 
				
			||||||
                    "geoip.country",
 | 
					                    "geoip.country",
 | 
				
			||||||
                    last_geo.country.to_dict(),
 | 
					                    last_geo.country,
 | 
				
			||||||
                    new_geo.country.to_dict(),
 | 
					                    new_geo.country,
 | 
				
			||||||
                    last_ip,
 | 
					                    last_ip,
 | 
				
			||||||
                    new_ip,
 | 
					                    new_ip,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
@ -204,8 +202,8 @@ class BoundSessionMiddleware(SessionMiddleware):
 | 
				
			|||||||
            if last_geo.city != new_geo.city:
 | 
					            if last_geo.city != new_geo.city:
 | 
				
			||||||
                raise SessionBindingBroken(
 | 
					                raise SessionBindingBroken(
 | 
				
			||||||
                    "geoip.city",
 | 
					                    "geoip.city",
 | 
				
			||||||
                    last_geo.city.to_dict(),
 | 
					                    last_geo.city,
 | 
				
			||||||
                    new_geo.city.to_dict(),
 | 
					                    new_geo.city,
 | 
				
			||||||
                    last_ip,
 | 
					                    last_ip,
 | 
				
			||||||
                    new_ip,
 | 
					                    new_ip,
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,6 @@
 | 
				
			|||||||
from time import sleep
 | 
					from time import sleep
 | 
				
			||||||
from unittest.mock import patch
 | 
					from unittest.mock import patch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django.http import HttpRequest
 | 
					 | 
				
			||||||
from django.urls import reverse
 | 
					from django.urls import reverse
 | 
				
			||||||
from django.utils.timezone import now
 | 
					from django.utils.timezone import now
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -18,12 +17,7 @@ from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
				
			|||||||
from authentik.lib.generators import generate_id
 | 
					from authentik.lib.generators import generate_id
 | 
				
			||||||
from authentik.lib.utils.time import timedelta_from_string
 | 
					from authentik.lib.utils.time import timedelta_from_string
 | 
				
			||||||
from authentik.root.middleware import ClientIPMiddleware
 | 
					from authentik.root.middleware import ClientIPMiddleware
 | 
				
			||||||
from authentik.stages.user_login.middleware import (
 | 
					from authentik.stages.user_login.models import UserLoginStage
 | 
				
			||||||
    BoundSessionMiddleware,
 | 
					 | 
				
			||||||
    SessionBindingBroken,
 | 
					 | 
				
			||||||
    logout_extra,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from authentik.stages.user_login.models import GeoIPBinding, NetworkBinding, UserLoginStage
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestUserLoginStage(FlowTestCase):
 | 
					class TestUserLoginStage(FlowTestCase):
 | 
				
			||||||
@ -198,52 +192,3 @@ class TestUserLoginStage(FlowTestCase):
 | 
				
			|||||||
        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
					        self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
 | 
				
			||||||
        response = self.client.get(reverse("authentik_api:application-list"))
 | 
					        response = self.client.get(reverse("authentik_api:application-list"))
 | 
				
			||||||
        self.assertEqual(response.status_code, 403)
 | 
					        self.assertEqual(response.status_code, 403)
 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_binding_net_break_log(self):
 | 
					 | 
				
			||||||
        """Test logout_extra with exception"""
 | 
					 | 
				
			||||||
        # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json
 | 
					 | 
				
			||||||
        for args, expect in [
 | 
					 | 
				
			||||||
            [[NetworkBinding.BIND_ASN, "8.8.8.8", "8.8.8.8"], ["network.missing"]],
 | 
					 | 
				
			||||||
            [[NetworkBinding.BIND_ASN, "1.0.0.1", "1.128.0.1"], ["network.asn"]],
 | 
					 | 
				
			||||||
            [
 | 
					 | 
				
			||||||
                [NetworkBinding.BIND_ASN_NETWORK, "12.81.96.1", "12.81.128.1"],
 | 
					 | 
				
			||||||
                ["network.asn_network"],
 | 
					 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
            [[NetworkBinding.BIND_ASN_NETWORK_IP, "1.0.0.1", "1.0.0.2"], ["network.ip"]],
 | 
					 | 
				
			||||||
        ]:
 | 
					 | 
				
			||||||
            with self.subTest(args[0]):
 | 
					 | 
				
			||||||
                with self.assertRaises(SessionBindingBroken) as cm:
 | 
					 | 
				
			||||||
                    BoundSessionMiddleware.recheck_session_net(*args)
 | 
					 | 
				
			||||||
                self.assertEqual(cm.exception.reason, expect[0])
 | 
					 | 
				
			||||||
                # Ensure the request can be logged without throwing errors
 | 
					 | 
				
			||||||
                self.client.force_login(self.user)
 | 
					 | 
				
			||||||
                request = HttpRequest()
 | 
					 | 
				
			||||||
                request.session = self.client.session
 | 
					 | 
				
			||||||
                request.user = self.user
 | 
					 | 
				
			||||||
                logout_extra(request, cm.exception)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_binding_geo_break_log(self):
 | 
					 | 
				
			||||||
        """Test logout_extra with exception"""
 | 
					 | 
				
			||||||
        # IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
 | 
					 | 
				
			||||||
        for args, expect in [
 | 
					 | 
				
			||||||
            [[GeoIPBinding.BIND_CONTINENT, "8.8.8.8", "8.8.8.8"], ["geoip.missing"]],
 | 
					 | 
				
			||||||
            [[GeoIPBinding.BIND_CONTINENT, "2.125.160.216", "67.43.156.1"], ["geoip.continent"]],
 | 
					 | 
				
			||||||
            [
 | 
					 | 
				
			||||||
                [GeoIPBinding.BIND_CONTINENT_COUNTRY, "81.2.69.142", "89.160.20.112"],
 | 
					 | 
				
			||||||
                ["geoip.country"],
 | 
					 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
            [
 | 
					 | 
				
			||||||
                [GeoIPBinding.BIND_CONTINENT_COUNTRY_CITY, "2.125.160.216", "81.2.69.142"],
 | 
					 | 
				
			||||||
                ["geoip.city"],
 | 
					 | 
				
			||||||
            ],
 | 
					 | 
				
			||||||
        ]:
 | 
					 | 
				
			||||||
            with self.subTest(args[0]):
 | 
					 | 
				
			||||||
                with self.assertRaises(SessionBindingBroken) as cm:
 | 
					 | 
				
			||||||
                    BoundSessionMiddleware.recheck_session_geo(*args)
 | 
					 | 
				
			||||||
                self.assertEqual(cm.exception.reason, expect[0])
 | 
					 | 
				
			||||||
                # Ensure the request can be logged without throwing errors
 | 
					 | 
				
			||||||
                self.client.force_login(self.user)
 | 
					 | 
				
			||||||
                request = HttpRequest()
 | 
					 | 
				
			||||||
                request.session = self.client.session
 | 
					 | 
				
			||||||
                request.user = self.user
 | 
					 | 
				
			||||||
                logout_extra(request, cm.exception)
 | 
					 | 
				
			||||||
 | 
				
			|||||||
@ -1,7 +1,6 @@
 | 
				
			|||||||
"""Serializer for tenants models"""
 | 
					"""Serializer for tenants models"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from django_tenants.utils import get_public_schema_name
 | 
					from django_tenants.utils import get_public_schema_name
 | 
				
			||||||
from rest_framework.fields import JSONField
 | 
					 | 
				
			||||||
from rest_framework.generics import RetrieveUpdateAPIView
 | 
					from rest_framework.generics import RetrieveUpdateAPIView
 | 
				
			||||||
from rest_framework.permissions import SAFE_METHODS
 | 
					from rest_framework.permissions import SAFE_METHODS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -13,8 +12,6 @@ from authentik.tenants.models import Tenant
 | 
				
			|||||||
class SettingsSerializer(ModelSerializer):
 | 
					class SettingsSerializer(ModelSerializer):
 | 
				
			||||||
    """Settings Serializer"""
 | 
					    """Settings Serializer"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    footer_links = JSONField(required=False)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    class Meta:
 | 
					    class Meta:
 | 
				
			||||||
        model = Tenant
 | 
					        model = Tenant
 | 
				
			||||||
        fields = [
 | 
					        fields = [
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,6 @@ def check_embedded_outpost_disabled(app_configs, **kwargs):
 | 
				
			|||||||
                "Embedded outpost must be disabled when tenants API is enabled.",
 | 
					                "Embedded outpost must be disabled when tenants API is enabled.",
 | 
				
			||||||
                hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
 | 
					                hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
 | 
				
			||||||
                "True, or disable the tenants API by setting tenants.enabled to False",
 | 
					                "True, or disable the tenants API by setting tenants.enabled to False",
 | 
				
			||||||
                id="ak.tenants.E001",
 | 
					 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
    return []
 | 
					    return []
 | 
				
			||||||
 | 
				
			|||||||
@ -2,7 +2,7 @@
 | 
				
			|||||||
    "$schema": "http://json-schema.org/draft-07/schema",
 | 
					    "$schema": "http://json-schema.org/draft-07/schema",
 | 
				
			||||||
    "$id": "https://goauthentik.io/blueprints/schema.json",
 | 
					    "$id": "https://goauthentik.io/blueprints/schema.json",
 | 
				
			||||||
    "type": "object",
 | 
					    "type": "object",
 | 
				
			||||||
    "title": "authentik 2025.6.2 Blueprint schema",
 | 
					    "title": "authentik 2025.6.1 Blueprint schema",
 | 
				
			||||||
    "required": [
 | 
					    "required": [
 | 
				
			||||||
        "version",
 | 
					        "version",
 | 
				
			||||||
        "entries"
 | 
					        "entries"
 | 
				
			||||||
@ -6628,16 +6628,11 @@
 | 
				
			|||||||
                    "title": "Severity",
 | 
					                    "title": "Severity",
 | 
				
			||||||
                    "description": "Controls which severity level the created notifications will have."
 | 
					                    "description": "Controls which severity level the created notifications will have."
 | 
				
			||||||
                },
 | 
					                },
 | 
				
			||||||
                "destination_group": {
 | 
					                "group": {
 | 
				
			||||||
                    "type": "string",
 | 
					                    "type": "string",
 | 
				
			||||||
                    "format": "uuid",
 | 
					                    "format": "uuid",
 | 
				
			||||||
                    "title": "Destination group",
 | 
					                    "title": "Group",
 | 
				
			||||||
                    "description": "Define which group of users this notification should be sent and shown to. If left empty, Notification won't ben sent."
 | 
					                    "description": "Define which group of users this notification should be sent and shown to. If left empty, Notification won't ben sent."
 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "destination_event_user": {
 | 
					 | 
				
			||||||
                    "type": "boolean",
 | 
					 | 
				
			||||||
                    "title": "Destination event user",
 | 
					 | 
				
			||||||
                    "description": "When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both."
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
            "required": []
 | 
					            "required": []
 | 
				
			||||||
@ -7345,7 +7340,6 @@
 | 
				
			|||||||
                        "authentik.enterprise.providers.google_workspace",
 | 
					                        "authentik.enterprise.providers.google_workspace",
 | 
				
			||||||
                        "authentik.enterprise.providers.microsoft_entra",
 | 
					                        "authentik.enterprise.providers.microsoft_entra",
 | 
				
			||||||
                        "authentik.enterprise.providers.ssf",
 | 
					                        "authentik.enterprise.providers.ssf",
 | 
				
			||||||
                        "authentik.enterprise.search",
 | 
					 | 
				
			||||||
                        "authentik.enterprise.stages.authenticator_endpoint_gdtc",
 | 
					                        "authentik.enterprise.stages.authenticator_endpoint_gdtc",
 | 
				
			||||||
                        "authentik.enterprise.stages.mtls",
 | 
					                        "authentik.enterprise.stages.mtls",
 | 
				
			||||||
                        "authentik.enterprise.stages.source",
 | 
					                        "authentik.enterprise.stages.source",
 | 
				
			||||||
@ -13310,12 +13304,6 @@
 | 
				
			|||||||
                        "format": "uuid"
 | 
					                        "format": "uuid"
 | 
				
			||||||
                    },
 | 
					                    },
 | 
				
			||||||
                    "title": "Device type restrictions"
 | 
					                    "title": "Device type restrictions"
 | 
				
			||||||
                },
 | 
					 | 
				
			||||||
                "max_attempts": {
 | 
					 | 
				
			||||||
                    "type": "integer",
 | 
					 | 
				
			||||||
                    "minimum": 0,
 | 
					 | 
				
			||||||
                    "maximum": 2147483647,
 | 
					 | 
				
			||||||
                    "title": "Max attempts"
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
            "required": []
 | 
					            "required": []
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ services:
 | 
				
			|||||||
    volumes:
 | 
					    volumes:
 | 
				
			||||||
      - redis:/data
 | 
					      - redis:/data
 | 
				
			||||||
  server:
 | 
					  server:
 | 
				
			||||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.2}
 | 
					    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.1}
 | 
				
			||||||
    restart: unless-stopped
 | 
					    restart: unless-stopped
 | 
				
			||||||
    command: server
 | 
					    command: server
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
@ -55,7 +55,7 @@ services:
 | 
				
			|||||||
      redis:
 | 
					      redis:
 | 
				
			||||||
        condition: service_healthy
 | 
					        condition: service_healthy
 | 
				
			||||||
  worker:
 | 
					  worker:
 | 
				
			||||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.2}
 | 
					    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.1}
 | 
				
			||||||
    restart: unless-stopped
 | 
					    restart: unless-stopped
 | 
				
			||||||
    command: worker
 | 
					    command: worker
 | 
				
			||||||
    environment:
 | 
					    environment:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										8
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								go.mod
									
									
									
									
									
								
							@ -6,7 +6,7 @@ require (
 | 
				
			|||||||
	beryju.io/ldap v0.1.0
 | 
						beryju.io/ldap v0.1.0
 | 
				
			||||||
	github.com/avast/retry-go/v4 v4.6.1
 | 
						github.com/avast/retry-go/v4 v4.6.1
 | 
				
			||||||
	github.com/coreos/go-oidc/v3 v3.14.1
 | 
						github.com/coreos/go-oidc/v3 v3.14.1
 | 
				
			||||||
	github.com/getsentry/sentry-go v0.34.0
 | 
						github.com/getsentry/sentry-go v0.33.0
 | 
				
			||||||
	github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
 | 
						github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
 | 
				
			||||||
	github.com/go-ldap/ldap/v3 v3.4.11
 | 
						github.com/go-ldap/ldap/v3 v3.4.11
 | 
				
			||||||
	github.com/go-openapi/runtime v0.28.0
 | 
						github.com/go-openapi/runtime v0.28.0
 | 
				
			||||||
@ -18,18 +18,18 @@ require (
 | 
				
			|||||||
	github.com/gorilla/sessions v1.4.0
 | 
						github.com/gorilla/sessions v1.4.0
 | 
				
			||||||
	github.com/gorilla/websocket v1.5.3
 | 
						github.com/gorilla/websocket v1.5.3
 | 
				
			||||||
	github.com/grafana/pyroscope-go v1.2.2
 | 
						github.com/grafana/pyroscope-go v1.2.2
 | 
				
			||||||
	github.com/jellydator/ttlcache/v3 v3.4.0
 | 
						github.com/jellydator/ttlcache/v3 v3.3.0
 | 
				
			||||||
	github.com/mitchellh/mapstructure v1.5.0
 | 
						github.com/mitchellh/mapstructure v1.5.0
 | 
				
			||||||
	github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
 | 
						github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
 | 
				
			||||||
	github.com/pires/go-proxyproto v0.8.1
 | 
						github.com/pires/go-proxyproto v0.8.1
 | 
				
			||||||
	github.com/prometheus/client_golang v1.22.0
 | 
						github.com/prometheus/client_golang v1.22.0
 | 
				
			||||||
	github.com/redis/go-redis/v9 v9.11.0
 | 
						github.com/redis/go-redis/v9 v9.10.0
 | 
				
			||||||
	github.com/sethvargo/go-envconfig v1.3.0
 | 
						github.com/sethvargo/go-envconfig v1.3.0
 | 
				
			||||||
	github.com/sirupsen/logrus v1.9.3
 | 
						github.com/sirupsen/logrus v1.9.3
 | 
				
			||||||
	github.com/spf13/cobra v1.9.1
 | 
						github.com/spf13/cobra v1.9.1
 | 
				
			||||||
	github.com/stretchr/testify v1.10.0
 | 
						github.com/stretchr/testify v1.10.0
 | 
				
			||||||
	github.com/wwt/guac v1.3.2
 | 
						github.com/wwt/guac v1.3.2
 | 
				
			||||||
	goauthentik.io/api/v3 v3.2025062.5
 | 
						goauthentik.io/api/v3 v3.2025061.2
 | 
				
			||||||
	golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
 | 
						golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
 | 
				
			||||||
	golang.org/x/oauth2 v0.30.0
 | 
						golang.org/x/oauth2 v0.30.0
 | 
				
			||||||
	golang.org/x/sync v0.15.0
 | 
						golang.org/x/sync v0.15.0
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										16
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								go.sum
									
									
									
									
									
								
							@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
 | 
				
			|||||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
 | 
					github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
 | 
				
			||||||
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
 | 
					github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
 | 
				
			||||||
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
 | 
					github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
 | 
				
			||||||
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
 | 
					github.com/getsentry/sentry-go v0.33.0 h1:YWyDii0KGVov3xOaamOnF0mjOrqSjBqwv48UEzn7QFg=
 | 
				
			||||||
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
 | 
					github.com/getsentry/sentry-go v0.33.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
 | 
				
			||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
 | 
					github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
 | 
				
			||||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
 | 
					github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
 | 
				
			||||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
 | 
					github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
 | 
				
			||||||
@ -203,8 +203,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
 | 
				
			|||||||
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
 | 
					github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
 | 
				
			||||||
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
 | 
					github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
 | 
				
			||||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
 | 
					github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
 | 
				
			||||||
github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY=
 | 
					github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc=
 | 
				
			||||||
github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4=
 | 
					github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw=
 | 
				
			||||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 | 
					github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
 | 
				
			||||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 | 
					github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
 | 
				
			||||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
 | 
					github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
 | 
				
			||||||
@ -251,8 +251,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
 | 
				
			|||||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
 | 
					github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
 | 
				
			||||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
 | 
					github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
 | 
				
			||||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 | 
					github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
 | 
				
			||||||
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
 | 
					github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
 | 
				
			||||||
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
 | 
					github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
 | 
				
			||||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 | 
					github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
 | 
				
			||||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
 | 
					github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
 | 
				
			||||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
 | 
					github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
 | 
				
			||||||
@ -298,8 +298,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
 | 
				
			|||||||
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
 | 
					go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
 | 
				
			||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 | 
					go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
 | 
				
			||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
 | 
					go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
 | 
				
			||||||
goauthentik.io/api/v3 v3.2025062.5 h1:+eQe3S+9WxrO0QczbSQUhtfnCB1w2rse5wmgMkcRUio=
 | 
					goauthentik.io/api/v3 v3.2025061.2 h1:bKmrl82Gz6J8lz3f+QIH9g+MEkl3MvkMXF34GktesA0=
 | 
				
			||||||
goauthentik.io/api/v3 v3.2025062.5/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
 | 
					goauthentik.io/api/v3 v3.2025061.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
 | 
				
			||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
					golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
				
			||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
					golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
				
			||||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
					golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
				
			||||||
 | 
				
			|||||||
@ -33,4 +33,4 @@ func UserAgent() string {
 | 
				
			|||||||
	return fmt.Sprintf("authentik@%s", FullVersion())
 | 
						return fmt.Sprintf("authentik@%s", FullVersion())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const VERSION = "2025.6.2"
 | 
					const VERSION = "2025.6.1"
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										8
									
								
								lifecycle/aws/package-lock.json
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										8
									
								
								lifecycle/aws/package-lock.json
									
									
									
										generated
									
									
									
								
							@ -9,7 +9,7 @@
 | 
				
			|||||||
            "version": "0.0.0",
 | 
					            "version": "0.0.0",
 | 
				
			||||||
            "license": "MIT",
 | 
					            "license": "MIT",
 | 
				
			||||||
            "devDependencies": {
 | 
					            "devDependencies": {
 | 
				
			||||||
                "aws-cdk": "^2.1019.1",
 | 
					                "aws-cdk": "^2.1018.1",
 | 
				
			||||||
                "cross-env": "^7.0.3"
 | 
					                "cross-env": "^7.0.3"
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
            "engines": {
 | 
					            "engines": {
 | 
				
			||||||
@ -17,9 +17,9 @@
 | 
				
			|||||||
            }
 | 
					            }
 | 
				
			||||||
        },
 | 
					        },
 | 
				
			||||||
        "node_modules/aws-cdk": {
 | 
					        "node_modules/aws-cdk": {
 | 
				
			||||||
            "version": "2.1019.1",
 | 
					            "version": "2.1018.1",
 | 
				
			||||||
            "resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1019.1.tgz",
 | 
					            "resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1018.1.tgz",
 | 
				
			||||||
            "integrity": "sha512-G2jxKuTsYTrYZX80CDApCrKcZ+AuFxxd+b0dkb0KEkfUsela7RqrDGLm5wOzSCIc3iH6GocR8JDVZuJ+0nNuKg==",
 | 
					            "integrity": "sha512-kFPRox5kSm+ktJ451o0ng9rD+60p5Kt1CZIWw8kXnvqbsxN2xv6qbmyWSXw7sGVXVwqrRKVj+71/JeDr+LMAZw==",
 | 
				
			||||||
            "dev": true,
 | 
					            "dev": true,
 | 
				
			||||||
            "license": "Apache-2.0",
 | 
					            "license": "Apache-2.0",
 | 
				
			||||||
            "bin": {
 | 
					            "bin": {
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@
 | 
				
			|||||||
        "node": ">=20"
 | 
					        "node": ">=20"
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    "devDependencies": {
 | 
					    "devDependencies": {
 | 
				
			||||||
        "aws-cdk": "^2.1019.1",
 | 
					        "aws-cdk": "^2.1018.1",
 | 
				
			||||||
        "cross-env": "^7.0.3"
 | 
					        "cross-env": "^7.0.3"
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -26,7 +26,7 @@ Parameters:
 | 
				
			|||||||
    Description: authentik Docker image
 | 
					    Description: authentik Docker image
 | 
				
			||||||
  AuthentikVersion:
 | 
					  AuthentikVersion:
 | 
				
			||||||
    Type: String
 | 
					    Type: String
 | 
				
			||||||
    Default: 2025.6.2
 | 
					    Default: 2025.6.1
 | 
				
			||||||
    Description: authentik Docker image tag
 | 
					    Description: authentik Docker image tag
 | 
				
			||||||
  AuthentikServerCPU:
 | 
					  AuthentikServerCPU:
 | 
				
			||||||
    Type: Number
 | 
					    Type: Number
 | 
				
			||||||
 | 
				
			|||||||
@ -7,6 +7,8 @@ from pathlib import Path
 | 
				
			|||||||
from tempfile import gettempdir
 | 
					from tempfile import gettempdir
 | 
				
			||||||
from typing import TYPE_CHECKING
 | 
					from typing import TYPE_CHECKING
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cryptography.hazmat.backends.openssl.backend import backend
 | 
				
			||||||
 | 
					from defusedxml import defuse_stdlib
 | 
				
			||||||
from prometheus_client.values import MultiProcessValue
 | 
					from prometheus_client.values import MultiProcessValue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik import get_full_version
 | 
					from authentik import get_full_version
 | 
				
			||||||
@ -16,7 +18,6 @@ from authentik.lib.logging import get_logger_config
 | 
				
			|||||||
from authentik.lib.utils.http import get_http_session
 | 
					from authentik.lib.utils.http import get_http_session
 | 
				
			||||||
from authentik.lib.utils.reflection import get_env
 | 
					from authentik.lib.utils.reflection import get_env
 | 
				
			||||||
from authentik.root.install_id import get_install_id_raw
 | 
					from authentik.root.install_id import get_install_id_raw
 | 
				
			||||||
from authentik.root.setup import setup
 | 
					 | 
				
			||||||
from lifecycle.migrate import run_migrations
 | 
					from lifecycle.migrate import run_migrations
 | 
				
			||||||
from lifecycle.wait_for_db import wait_for_db
 | 
					from lifecycle.wait_for_db import wait_for_db
 | 
				
			||||||
from lifecycle.worker import DjangoUvicornWorker
 | 
					from lifecycle.worker import DjangoUvicornWorker
 | 
				
			||||||
@ -27,7 +28,10 @@ if TYPE_CHECKING:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    from authentik.root.asgi import AuthentikAsgi
 | 
					    from authentik.root.asgi import AuthentikAsgi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
setup()
 | 
					defuse_stdlib()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if CONFIG.get_bool("compliance.fips.enabled", False):
 | 
				
			||||||
 | 
					    backend._enable_fips()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
wait_for_db()
 | 
					wait_for_db()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -10,7 +10,7 @@ from typing import Any
 | 
				
			|||||||
from psycopg import Connection, Cursor, connect
 | 
					from psycopg import Connection, Cursor, connect
 | 
				
			||||||
from structlog.stdlib import get_logger
 | 
					from structlog.stdlib import get_logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from authentik.lib.config import CONFIG, django_db_config
 | 
					from authentik.lib.config import CONFIG
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOGGER = get_logger()
 | 
					LOGGER = get_logger()
 | 
				
			||||||
ADV_LOCK_UID = 1000
 | 
					ADV_LOCK_UID = 1000
 | 
				
			||||||
@ -115,13 +115,9 @@ def run_migrations():
 | 
				
			|||||||
        execute_from_command_line(["", "migrate_schemas"])
 | 
					        execute_from_command_line(["", "migrate_schemas"])
 | 
				
			||||||
        if CONFIG.get_bool("tenants.enabled", False):
 | 
					        if CONFIG.get_bool("tenants.enabled", False):
 | 
				
			||||||
            execute_from_command_line(["", "migrate_schemas", "--schema", "template", "--tenant"])
 | 
					            execute_from_command_line(["", "migrate_schemas", "--schema", "template", "--tenant"])
 | 
				
			||||||
        # Run django system checks for all databases
 | 
					        execute_from_command_line(
 | 
				
			||||||
        check_args = ["", "check"]
 | 
					            ["", "check"] + ([] if CONFIG.get_bool("debug") else ["--deploy"])
 | 
				
			||||||
        for label in django_db_config(CONFIG).keys():
 | 
					        )
 | 
				
			||||||
            check_args.append(f"--database={label}")
 | 
					 | 
				
			||||||
        if not CONFIG.get_bool("debug"):
 | 
					 | 
				
			||||||
            check_args.append("--deploy")
 | 
					 | 
				
			||||||
        execute_from_command_line(check_args)
 | 
					 | 
				
			||||||
    finally:
 | 
					    finally:
 | 
				
			||||||
        release_lock(curr)
 | 
					        release_lock(curr)
 | 
				
			||||||
        curr.close()
 | 
					        curr.close()
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,7 @@ msgid ""
 | 
				
			|||||||
msgstr ""
 | 
					msgstr ""
 | 
				
			||||||
"Project-Id-Version: PACKAGE VERSION\n"
 | 
					"Project-Id-Version: PACKAGE VERSION\n"
 | 
				
			||||||
"Report-Msgid-Bugs-To: \n"
 | 
					"Report-Msgid-Bugs-To: \n"
 | 
				
			||||||
"POT-Creation-Date: 2025-06-25 00:10+0000\n"
 | 
					"POT-Creation-Date: 2025-06-04 00:12+0000\n"
 | 
				
			||||||
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
 | 
					"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
 | 
				
			||||||
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
 | 
					"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
 | 
				
			||||||
"Language-Team: LANGUAGE <LL@li.org>\n"
 | 
					"Language-Team: LANGUAGE <LL@li.org>\n"
 | 
				
			||||||
@ -109,6 +109,10 @@ msgstr ""
 | 
				
			|||||||
msgid "User does not have access to application."
 | 
					msgid "User does not have access to application."
 | 
				
			||||||
msgstr ""
 | 
					msgstr ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#: authentik/core/api/devices.py
 | 
				
			||||||
 | 
					msgid "Extra description not available"
 | 
				
			||||||
 | 
					msgstr ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#: authentik/core/api/groups.py
 | 
					#: authentik/core/api/groups.py
 | 
				
			||||||
msgid "Cannot set group as parent of itself."
 | 
					msgid "Cannot set group as parent of itself."
 | 
				
			||||||
msgstr ""
 | 
					msgstr ""
 | 
				
			||||||
@ -759,12 +763,6 @@ msgid ""
 | 
				
			|||||||
"If left empty, Notification won't ben sent."
 | 
					"If left empty, Notification won't ben sent."
 | 
				
			||||||
msgstr ""
 | 
					msgstr ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#: authentik/events/models.py
 | 
					 | 
				
			||||||
msgid ""
 | 
					 | 
				
			||||||
"When enabled, notification will be sent to user the user that triggered the "
 | 
					 | 
				
			||||||
"event.When destination_group is configured, notification is sent to both."
 | 
					 | 
				
			||||||
msgstr ""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#: authentik/events/models.py
 | 
					#: authentik/events/models.py
 | 
				
			||||||
msgid "Notification Rule"
 | 
					msgid "Notification Rule"
 | 
				
			||||||
msgstr ""
 | 
					msgstr ""
 | 
				
			||||||
 | 
				
			|||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user