Compare commits

..

2 Commits

Author SHA1 Message Date
7549a6b83d add cause for oauth authorization errors
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-16 03:04:20 +02:00
bb45b714e2 fix incorrect tests/add more
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-16 02:54:16 +02:00
382 changed files with 9525 additions and 8520 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2025.6.3 current_version = 2025.6.1
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
@ -21,8 +21,6 @@ optional_value = final
[bumpversion:file:package.json] [bumpversion:file:package.json]
[bumpversion:file:package-lock.json]
[bumpversion:file:docker-compose.yml] [bumpversion:file:docker-compose.yml]
[bumpversion:file:schema.yml] [bumpversion:file:schema.yml]
@ -33,4 +31,6 @@ optional_value = final
[bumpversion:file:internal/constants/constants.go] [bumpversion:file:internal/constants/constants.go]
[bumpversion:file:web/src/common/constants.ts]
[bumpversion:file:lifecycle/aws/template.yaml] [bumpversion:file:lifecycle/aws/template.yaml]

View File

@ -7,9 +7,6 @@ charset = utf-8
trim_trailing_whitespace = true trim_trailing_whitespace = true
insert_final_newline = true insert_final_newline = true
[*.toml]
indent_size = 2
[*.html] [*.html]
indent_size = 2 indent_size = 2

View File

@ -38,8 +38,6 @@ jobs:
# Needed for attestation # Needed for attestation
id-token: write id-token: write
attestations: write attestations: write
# Needed for checkout
contents: read
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: docker/setup-qemu-action@v3.6.0 - uses: docker/setup-qemu-action@v3.6.0

View File

@ -9,15 +9,14 @@ on:
jobs: jobs:
test-container: test-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
version: version:
- docs - docs
- version-2025-4
- version-2025-2 - version-2025-2
- version-2024-12
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- run: | - run: |

View File

@ -202,7 +202,7 @@ jobs:
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: web/dist path: web/dist
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
- name: prepare web ui - name: prepare web ui
if: steps.cache-web.outputs.cache-hit != 'true' if: steps.cache-web.outputs.cache-hit != 'true'
working-directory: web working-directory: web
@ -247,13 +247,11 @@ jobs:
# Needed for attestation # Needed for attestation
id-token: write id-token: write
attestations: write attestations: write
# Needed for checkout
contents: read
needs: ci-core-mark needs: ci-core-mark
uses: ./.github/workflows/_reusable-docker-build.yaml uses: ./.github/workflows/_reusable-docker-build.yaml
secrets: inherit secrets: inherit
with: with:
image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }} image_name: ghcr.io/goauthentik/dev-server
release: false release: false
pr-comment: pr-comment:
needs: needs:

View File

@ -59,7 +59,6 @@ jobs:
with: with:
jobs: ${{ toJSON(needs) }} jobs: ${{ toJSON(needs) }}
build-container: build-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
timeout-minutes: 120 timeout-minutes: 120
needs: needs:
- ci-outpost-mark - ci-outpost-mark

View File

@ -41,29 +41,7 @@ jobs:
- name: test - name: test
working-directory: website/ working-directory: website/
run: npm test run: npm test
build:
runs-on: ubuntu-latest
name: ${{ matrix.job }}
strategy:
fail-fast: false
matrix:
job:
- build
- build:integrations
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: website/package.json
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
run: npm ci
- name: build
working-directory: website/
run: npm run ${{ matrix.job }}
build-container: build-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
# Needed to upload container images to ghcr.io # Needed to upload container images to ghcr.io
@ -116,11 +94,9 @@ jobs:
needs: needs:
- lint - lint
- test - test
- build
- build-container - build-container
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: re-actors/alls-green@release/v1 - uses: re-actors/alls-green@release/v1
with: with:
jobs: ${{ toJSON(needs) }} jobs: ${{ toJSON(needs) }}
allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }}

View File

@ -2,7 +2,7 @@ name: "CodeQL"
on: on:
push: push:
branches: [main, next, version*] branches: [main, "*", next, version*]
pull_request: pull_request:
branches: [main] branches: [main]
schedule: schedule:

View File

@ -1,21 +0,0 @@
name: "authentik-repo-mirror-cleanup"
on:
workflow_dispatch:
jobs:
to_internal:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- if: ${{ env.MIRROR_KEY != '' }}
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
with:
target_repo_url: git@github.com:goauthentik/authentik-internal.git
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
args: --tags --force --prune
env:
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}

View File

@ -11,10 +11,11 @@ jobs:
with: with:
fetch-depth: 0 fetch-depth: 0
- if: ${{ env.MIRROR_KEY != '' }} - if: ${{ env.MIRROR_KEY != '' }}
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb uses: pixta-dev/repository-mirroring-action@v1
with: with:
target_repo_url: git@github.com:goauthentik/authentik-internal.git target_repo_url:
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }} git@github.com:goauthentik/authentik-internal.git
args: --tags --force ssh_private_key:
${{ secrets.GH_MIRROR_KEY }}
env: env:
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }} MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}

View File

@ -16,7 +16,6 @@ env:
jobs: jobs:
compile: compile:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- id: generate_token - id: generate_token

View File

@ -6,15 +6,13 @@
"!Context scalar", "!Context scalar",
"!Enumerate sequence", "!Enumerate sequence",
"!Env scalar", "!Env scalar",
"!Env sequence",
"!Find sequence", "!Find sequence",
"!Format sequence", "!Format sequence",
"!If sequence", "!If sequence",
"!Index scalar", "!Index scalar",
"!KeyOf scalar", "!KeyOf scalar",
"!Value scalar", "!Value scalar",
"!AtIndex scalar", "!AtIndex scalar"
"!ParseJSON scalar"
], ],
"typescript.preferences.importModuleSpecifier": "non-relative", "typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index", "typescript.preferences.importModuleSpecifierEnding": "index",

View File

@ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0" /bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 4: Download uv # Stage 4: Download uv
FROM ghcr.io/astral-sh/uv:0.7.17 AS uv FROM ghcr.io/astral-sh/uv:0.7.13 AS uv
# Stage 5: Base python image # Stage 5: Base python image
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base
ENV VENV_PATH="/ak-root/.venv" \ ENV VENV_PATH="/ak-root/.venv" \
PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \ PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \

View File

@ -86,10 +86,6 @@ dev-create-db:
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
update-test-mmdb: ## Update test GeoIP and ASN Databases
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb
######################### #########################
## API Schema ## API Schema
######################### #########################
@ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescri
--additional-properties=npmVersion=${NPM_VERSION} \ --additional-properties=npmVersion=${NPM_VERSION} \
--git-repo-id authentik \ --git-repo-id authentik \
--git-user-id goauthentik --git-user-id goauthentik
mkdir -p web/node_modules/@goauthentik/api
cd ${PWD}/${GEN_API_TS} && npm link cd ${PWD}/${GEN_API_TS} && npm i
cd ${PWD}/web && npm link @goauthentik/api \cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api
gen-client-py: gen-clean-py ## Build and install the authentik API for Python gen-client-py: gen-clean-py ## Build and install the authentik API for Python
docker run \ docker run \

View File

@ -2,7 +2,7 @@
from os import environ from os import environ
__version__ = "2025.6.3" __version__ = "2025.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -1,67 +0,0 @@
from rest_framework.routers import DefaultRouter as UpstreamDefaultRouter
from rest_framework.viewsets import ViewSet
from rest_framework_nested.routers import NestedMixin
class DefaultRouter(UpstreamDefaultRouter):
include_format_suffixes = False
class NestedRouter(DefaultRouter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.nested_routers = []
class nested:
def __init__(self, parent: "NestedRouter", prefix: str):
self.parent = parent
self.prefix = prefix
self.inner = None
def nested(self, lookup: str, prefix: str, viewset: type[ViewSet]):
if not self.inner:
self.inner = NestedDefaultRouter(self.parent, self.prefix, lookup=lookup)
self.inner.register(prefix, viewset)
return self
@property
def urls(self):
return self.parent.urls
def register(self, prefix, viewset, basename=None):
super().register(prefix, viewset, basename)
nested_router = self.nested(self, prefix)
self.nested_routers.append(nested_router)
return nested_router
def get_urls(self):
urls = super().get_urls()
for nested in self.nested_routers:
if not nested.inner:
continue
urls.extend(nested.inner.urls)
return urls
class NestedDefaultRouter(NestedMixin, DefaultRouter):
...
# def __init__(self, *args, **kwargs):
# self.args = args
# self.kwargs = kwargs
# self.routes = []
# def register(self, *args, **kwargs):
# self.routes.append((args, kwargs))
# @property
# def urls(self):
# class r(NestedMixin, DefaultRouter):
# ...
# router = r(*self.args, **self.kwargs)
# for route_args, route_kwrags in self.routes:
# router.register(*route_args, **route_kwrags)
# return router
root_router = DefaultRouter()

View File

@ -6,15 +6,18 @@ from django.urls import path
from django.urls.resolvers import URLPattern from django.urls.resolvers import URLPattern
from django.views.decorators.cache import cache_page from django.views.decorators.cache import cache_page
from drf_spectacular.views import SpectacularAPIView from drf_spectacular.views import SpectacularAPIView
from rest_framework import routers
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.api.v3.config import ConfigView from authentik.api.v3.config import ConfigView
from authentik.api.v3.routers import root_router
from authentik.api.views import APIBrowserView from authentik.api.views import APIBrowserView
from authentik.lib.utils.reflection import get_apps from authentik.lib.utils.reflection import get_apps
LOGGER = get_logger() LOGGER = get_logger()
router = routers.DefaultRouter()
router.include_format_suffixes = False
_other_urls = [] _other_urls = []
for _authentik_app in get_apps(): for _authentik_app in get_apps():
try: try:
@ -35,7 +38,7 @@ for _authentik_app in get_apps():
if isinstance(url, URLPattern): if isinstance(url, URLPattern):
_other_urls.append(url) _other_urls.append(url)
else: else:
root_router.register(*url) router.register(*url)
LOGGER.debug( LOGGER.debug(
"Mounted API URLs", "Mounted API URLs",
app_name=_authentik_app.name, app_name=_authentik_app.name,
@ -46,7 +49,7 @@ urlpatterns = (
[ [
path("", APIBrowserView.as_view(), name="schema-browser"), path("", APIBrowserView.as_view(), name="schema-browser"),
] ]
+ root_router.urls + router.urls
+ _other_urls + _other_urls
+ [ + [
path("root/config/", ConfigView.as_view(), name="config"), path("root/config/", ConfigView.as_view(), name="config"),

View File

@ -37,7 +37,6 @@ entries:
- attrs: - attrs:
attributes: attributes:
env_null: !Env [bar-baz, null] env_null: !Env [bar-baz, null]
json_parse: !ParseJSON '{"foo": "bar"}'
policy_pk1: policy_pk1:
!Format [ !Format [
"%s-%s", "%s-%s",

View File

@ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable:
for blueprint_file in Path("blueprints/").glob("**/*.yaml"): for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
if "local" in str(blueprint_file) or "testing" in str(blueprint_file): if "local" in str(blueprint_file):
continue continue
setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file)) setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))

View File

@ -5,6 +5,7 @@ from collections.abc import Callable
from django.apps import apps from django.apps import apps
from django.test import TestCase from django.test import TestCase
from authentik.blueprints.v1.importer import is_model_allowed
from authentik.lib.models import SerializerModel from authentik.lib.models import SerializerModel
from authentik.providers.oauth2.models import RefreshToken from authentik.providers.oauth2.models import RefreshToken
@ -21,13 +22,10 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable:
return return
model_class = test_model() model_class = test_model()
self.assertTrue(isinstance(model_class, SerializerModel)) self.assertTrue(isinstance(model_class, SerializerModel))
# Models that have subclasses don't have to have a serializer
if len(test_model.__subclasses__()) > 0:
return
self.assertIsNotNone(model_class.serializer) self.assertIsNotNone(model_class.serializer)
if model_class.serializer.Meta().model == RefreshToken: if model_class.serializer.Meta().model == RefreshToken:
return return
self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model)) self.assertEqual(model_class.serializer.Meta().model, test_model)
return tester return tester
@ -36,6 +34,6 @@ for app in apps.get_app_configs():
if not app.label.startswith("authentik"): if not app.label.startswith("authentik"):
continue continue
for model in app.get_models(): for model in app.get_models():
if not issubclass(model, SerializerModel): if not is_model_allowed(model):
continue continue
setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model)) setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))

View File

@ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase):
}, },
"nested_context": "context-nested-value", "nested_context": "context-nested-value",
"env_null": None, "env_null": None,
"json_parse": {"foo": "bar"},
"at_index_sequence": "foo", "at_index_sequence": "foo",
"at_index_sequence_default": "non existent", "at_index_sequence_default": "non existent",
"at_index_mapping": 2, "at_index_mapping": 2,

View File

@ -6,7 +6,6 @@ from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from json import JSONDecodeError, loads
from operator import ixor from operator import ixor
from os import getenv from os import getenv
from typing import Any, Literal, Union from typing import Any, Literal, Union
@ -292,22 +291,6 @@ class Context(YAMLTag):
return value return value
class ParseJSON(YAMLTag):
"""Parse JSON from context/env/etc value"""
raw: str
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
super().__init__()
self.raw = node.value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
try:
return loads(self.raw)
except JSONDecodeError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc
class Format(YAMLTag): class Format(YAMLTag):
"""Format a string""" """Format a string"""
@ -683,7 +666,6 @@ class BlueprintLoader(SafeLoader):
self.add_constructor("!Value", Value) self.add_constructor("!Value", Value)
self.add_constructor("!Index", Index) self.add_constructor("!Index", Index)
self.add_constructor("!AtIndex", AtIndex) self.add_constructor("!AtIndex", AtIndex)
self.add_constructor("!ParseJSON", ParseJSON)
class EntryInvalidError(SentryIgnoredException): class EntryInvalidError(SentryIgnoredException):

View File

@ -1,6 +1,8 @@
"""Authenticator Devices API Views""" """Authenticator Devices API Views"""
from drf_spectacular.utils import extend_schema from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from rest_framework.fields import ( from rest_framework.fields import (
BooleanField, BooleanField,
@ -13,7 +15,6 @@ from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.viewsets import ViewSet from rest_framework.viewsets import ViewSet
from authentik.core.api.users import ParamUserSerializer
from authentik.core.api.utils import MetaNameSerializer from authentik.core.api.utils import MetaNameSerializer
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
from authentik.stages.authenticator import device_classes, devices_for_user from authentik.stages.authenticator import device_classes, devices_for_user
@ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
class DeviceSerializer(MetaNameSerializer): class DeviceSerializer(MetaNameSerializer):
"""Serializer for authenticator devices""" """Serializer for Duo authenticator devices"""
pk = CharField() pk = CharField()
name = CharField() name = CharField()
@ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer):
last_updated = DateTimeField(read_only=True) last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True) last_used = DateTimeField(read_only=True, allow_null=True)
extra_description = SerializerMethodField() extra_description = SerializerMethodField()
external_id = SerializerMethodField()
def get_type(self, instance: Device) -> str: def get_type(self, instance: Device) -> str:
"""Get type of device""" """Get type of device"""
return instance._meta.label return instance._meta.label
def get_extra_description(self, instance: Device) -> str | None: def get_extra_description(self, instance: Device) -> str:
"""Get extra description""" """Get extra description"""
if isinstance(instance, WebAuthnDevice): if isinstance(instance, WebAuthnDevice):
return instance.device_type.description if instance.device_type else None return (
instance.device_type.description
if instance.device_type
else _("Extra description not available")
)
if isinstance(instance, EndpointDevice): if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel") return instance.data.get("deviceSignals", {}).get("deviceModel")
return None return ""
def get_external_id(self, instance: Device) -> str | None:
"""Get external Device ID"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.aaguid if instance.device_type else None
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
class DeviceViewSet(ViewSet): class DeviceViewSet(ViewSet):
@ -61,6 +57,7 @@ class DeviceViewSet(ViewSet):
serializer_class = DeviceSerializer serializer_class = DeviceSerializer
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
@extend_schema(responses={200: DeviceSerializer(many=True)})
def list(self, request: Request) -> Response: def list(self, request: Request) -> Response:
"""Get all devices for current user""" """Get all devices for current user"""
devices = devices_for_user(request.user) devices = devices_for_user(request.user)
@ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet):
yield from device_set yield from device_set
@extend_schema( @extend_schema(
parameters=[ParamUserSerializer], parameters=[
OpenApiParameter(
name="user",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
)
],
responses={200: DeviceSerializer(many=True)}, responses={200: DeviceSerializer(many=True)},
) )
def list(self, request: Request) -> Response: def list(self, request: Request) -> Response:
"""Get all devices for current user""" """Get all devices for current user"""
args = ParamUserSerializer(data=request.query_params) kwargs = {}
args.is_valid(raise_exception=True) if "user" in request.query_params:
return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data) kwargs = {"user": request.query_params["user"]}
return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data)

View File

@ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger() LOGGER = get_logger()
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False)
class UserGroupSerializer(ModelSerializer): class UserGroupSerializer(ModelSerializer):
"""Simplified Group Serializer for user's groups""" """Simplified Group Serializer for user's groups"""
@ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
queryset = User.objects.none() queryset = User.objects.none()
ordering = ["username"] ordering = ["username"]
serializer_class = UserSerializer serializer_class = UserSerializer
filterset_class = UsersFilter
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"] search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
filterset_class = UsersFilter
def get_ql_fields(self):
from djangoql.schema import BoolField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
StrField(User, "username"),
StrField(User, "name"),
StrField(User, "email"),
StrField(User, "path"),
BoolField(User, "is_active", nullable=True),
ChoiceSearchField(User, "type"),
JSONSearchField(User, "attributes", suggest_nested=False),
]
def get_queryset(self): def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous() base_qs = User.objects.all().exclude_anonymous()

View File

@ -2,7 +2,6 @@
from typing import Any from typing import Any
from django.db import models
from django.db.models import Model from django.db.models import Model
from drf_spectacular.extensions import OpenApiSerializerFieldExtension from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_basic_type from drf_spectacular.plumbing import build_basic_type
@ -31,27 +30,7 @@ def is_dict(value: Any):
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.") raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class ModelSerializer(BaseModelSerializer): class ModelSerializer(BaseModelSerializer):
# By default, JSON fields we have are used to store dictionaries
serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
serializer_field_mapping[models.JSONField] = JSONDictField
def create(self, validated_data): def create(self, validated_data):
instance = super().create(validated_data) instance = super().create(validated_data)
@ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer):
return instance return instance
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class PassiveSerializer(Serializer): class PassiveSerializer(Serializer):
"""Base serializer class which doesn't implement create/update methods""" """Base serializer class which doesn't implement create/update methods"""

View File

@ -13,6 +13,7 @@ class Command(TenantCommand):
parser.add_argument("usernames", nargs="*", type=str) parser.add_argument("usernames", nargs="*", type=str)
def handle_per_tenant(self, **options): def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"]) new_type = UserTypes(options["type"])
qs = ( qs = (
User.objects.exclude_anonymous() User.objects.exclude_anonymous()

View File

@ -18,7 +18,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_cte import CTE, with_cte from django_cte import CTEQuerySet, With
from guardian.conf import settings from guardian.conf import settings
from guardian.mixins import GuardianUserMixin from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager from model_utils.managers import InheritanceManager
@ -136,7 +136,7 @@ class AttributesMixin(models.Model):
return instance, False return instance, False
class GroupQuerySet(QuerySet): class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self): def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents """Recursively get all groups that have the current queryset as parents
or are indirectly related.""" or are indirectly related."""
@ -165,9 +165,9 @@ class GroupQuerySet(QuerySet):
) )
# Build the recursive query, see above # Build the recursive query, see above
cte = CTE.recursive(make_cte) cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group. # Return the result, as a usable queryset for Group.
return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid)) return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin): class Group(SerializerModel, AttributesMixin):
@ -1082,12 +1082,6 @@ class AuthenticatedSession(SerializerModel):
user = models.ForeignKey(User, on_delete=models.CASCADE) user = models.ForeignKey(User, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer
return AuthenticatedSessionSerializer
class Meta: class Meta:
verbose_name = _("Authenticated Session") verbose_name = _("Authenticated Session")
verbose_name_plural = _("Authenticated Sessions") verbose_name_plural = _("Authenticated Sessions")

View File

@ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,
@ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,

View File

@ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase):
[ [
UserPasswordHistory( UserPasswordHistory(
user=self.user, user=self.user,
old_password="hunter1", # nosec old_password="hunter1", # nosec B106
created_at=_now - timedelta(days=3), created_at=_now - timedelta(days=3),
), ),
UserPasswordHistory( UserPasswordHistory(
user=self.user, user=self.user,
old_password="hunter2", # nosec old_password="hunter2", # nosec B106
created_at=_now - timedelta(days=2), created_at=_now - timedelta(days=2),
), ),
UserPasswordHistory( UserPasswordHistory(
user=self.user, user=self.user,
old_password="hunter3", # nosec old_password="hunter3", # nosec B106
created_at=_now, created_at=_now,
), ),
] ]

View File

@ -1,8 +1,10 @@
from hashlib import sha256 from hashlib import sha256
from django.contrib.auth.signals import user_logged_out
from django.db.models import Model from django.db.models import Model
from django.db.models.signals import post_delete, post_save, pre_delete from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.http.request import HttpRequest
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
from authentik.core.models import ( from authentik.core.models import (
@ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
instance.save() instance.save()
@receiver(user_logged_out)
def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_):
"""Session revoked trigger (user logged out)"""
if not request.session or not request.session.session_key or not user:
return
send_ssf_event(
EventTypes.CAEP_SESSION_REVOKED,
{
"initiating_entity": "user",
},
sub_id={
"format": "complex",
"session": {
"format": "opaque",
"id": sha256(request.session.session_key.encode("ascii")).hexdigest(),
},
"user": {
"format": "email",
"email": user.email,
},
},
request=request,
)
@receiver(pre_delete, sender=AuthenticatedSession) @receiver(pre_delete, sender=AuthenticatedSession)
def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_): def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
"""Session revoked trigger (users' session has been deleted) """Session revoked trigger (users' session has been deleted)

View File

@ -1,12 +0,0 @@
"""Enterprise app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseSearchConfig(EnterpriseConfig):
"""Enterprise app config"""
name = "authentik.enterprise.search"
label = "authentik_search"
verbose_name = "authentik Enterprise.Search"
default = True

View File

@ -1,128 +0,0 @@
"""DjangoQL search"""
from collections import OrderedDict, defaultdict
from collections.abc import Generator
from django.db import connection
from django.db.models import Model, Q
from djangoql.compat import text_type
from djangoql.schema import StrField
class JSONSearchField(StrField):
"""JSON field for DjangoQL"""
model: Model
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
# Set this in the constructor to not clobber the type variable
self.type = "relation"
self.suggest_nested = suggest_nested
super().__init__(model, name, nullable)
def get_lookup(self, path, operator, value):
search = "__".join(path)
op, invert = self.get_operator(operator)
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
return ~q if invert else q
def json_field_keys(self) -> Generator[tuple[str]]:
with connection.cursor() as cursor:
cursor.execute(
f"""
WITH RECURSIVE "{self.name}_keys" AS (
SELECT
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
FROM {self.model._meta.db_table}
WHERE "{self.name}" IS NOT NULL
AND jsonb_typeof("{self.name}") = 'object'
UNION ALL
SELECT
ck.key_path_array || jsonb_object_keys(ck.value),
ck.value -> jsonb_object_keys(ck.value) AS value
FROM "{self.name}_keys" ck
WHERE jsonb_typeof(ck.value) = 'object'
),
unique_paths AS (
SELECT DISTINCT key_path_array
FROM "{self.name}_keys"
)
SELECT key_path_array FROM unique_paths;
""" # nosec
)
return (x[0] for x in cursor.fetchall())
def get_nested_options(self) -> OrderedDict:
"""Get keys of all nested objects to show autocomplete"""
if not self.suggest_nested:
return OrderedDict()
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
if not parent_parts:
parent_parts = []
path = parts.pop(0)
parent_parts.append(path)
relation_key = "_".join(parent_parts)
if len(parts) > 1:
out_dict = {
relation_key: {
parts[0]: {
"type": "relation",
"relation": f"{relation_key}_{parts[0]}",
}
}
}
child_paths = recursive_function(parts.copy(), parent_parts.copy())
child_paths.update(out_dict)
return child_paths
else:
return {relation_key: {parts[0]: {}}}
relation_structure = defaultdict(dict)
for relations in self.json_field_keys():
result = recursive_function([base_model_name] + relations)
for relation_key, value in result.items():
for sub_relation_key, sub_value in value.items():
if not relation_structure[relation_key].get(sub_relation_key, None):
relation_structure[relation_key][sub_relation_key] = sub_value
else:
relation_structure[relation_key][sub_relation_key].update(sub_value)
final_dict = defaultdict(dict)
for key, value in relation_structure.items():
for sub_key, sub_value in value.items():
if not sub_value:
final_dict[key][sub_key] = {
"type": "str",
"nullable": True,
}
else:
final_dict[key][sub_key] = sub_value
return OrderedDict(final_dict)
def relation(self) -> str:
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
class ChoiceSearchField(StrField):
def __init__(self, model=None, name=None, nullable=None):
super().__init__(model, name, nullable, suggest_options=True)
def get_options(self, search):
result = []
choices = self._field_choices()
if choices:
search = search.lower()
for c in choices:
choice = text_type(c[0])
if search in choice.lower():
result.append(choice)
return result

View File

@ -1,53 +0,0 @@
from rest_framework.response import Response
from authentik.api.pagination import Pagination
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch
class AutocompletePagination(Pagination):
def paginate_queryset(self, queryset, request, view=None):
self.view = view
return super().paginate_queryset(queryset, request, view)
def get_autocomplete(self):
schema = QLSearch().get_schema(self.request, self.view)
introspections = {}
if hasattr(self.view, "get_ql_fields"):
from authentik.enterprise.search.schema import AKQLSchemaSerializer
introspections = AKQLSchemaSerializer().serialize(
schema(self.page.paginator.object_list.model)
)
return introspections
def get_paginated_response(self, data):
previous_page_number = 0
if self.page.has_previous():
previous_page_number = self.page.previous_page_number()
next_page_number = 0
if self.page.has_next():
next_page_number = self.page.next_page_number()
return Response(
{
"pagination": {
"next": next_page_number,
"previous": previous_page_number,
"count": self.page.paginator.count,
"current": self.page.number,
"total_pages": self.page.paginator.num_pages,
"start_index": self.page.start_index(),
"end_index": self.page.end_index(),
},
"results": data,
"autocomplete": self.get_autocomplete(),
}
)
def get_paginated_response_schema(self, schema):
final_schema = super().get_paginated_response_schema(schema)
final_schema["properties"]["autocomplete"] = {
"$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}"
}
final_schema["required"].append("autocomplete")
return final_schema

View File

@ -1,81 +0,0 @@
"""DjangoQL search"""
from django.apps import apps
from django.db.models import QuerySet
from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
from authentik.enterprise.search.fields import JSONSearchField
LOGGER = get_logger()
AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete"
AUTOCOMPLETE_SCHEMA = {
"type": "object",
"additionalProperties": {},
}
class BaseSchema(DjangoQLSchema):
"""Base Schema which deals with JSON Fields"""
def resolve_name(self, name: Name):
model = self.model_label(self.current_model)
root_field = name.parts[0]
field = self.models[model].get(root_field)
# If the query goes into a JSON field, return the root
# field as the JSON field will do the rest
if isinstance(field, JSONSearchField):
# This is a workaround; build_filter will remove the right-most
# entry in the path as that is intended to be the same as the field
# however for JSON that is not the case
if name.parts[-1] != root_field:
name.parts.append(root_field)
return field
return super().resolve_name(name)
# Inherits from SearchFilter to keep the schema correctly
class QLSearch(SearchFilter):
"""rest_framework search filter which uses DjangoQL"""
def __init__(self):
super().__init__()
self._fallback = SearchFilter()
@property
def enabled(self):
return apps.get_app_config("authentik_enterprise").enabled()
def get_search_terms(self, request: Request) -> str:
"""Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited."""
params = request.query_params.get("search", "")
params = params.replace("\x00", "") # strip null characters
return params
def get_schema(self, request: Request, view) -> BaseSchema:
ql_fields = []
if hasattr(view, "get_ql_fields"):
ql_fields = view.get_ql_fields()
class InlineSchema(BaseSchema):
def get_fields(self, model):
return ql_fields or []
return InlineSchema
def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
search_query = self.get_search_terms(request)
schema = self.get_schema(request, view)
if len(search_query) == 0 or not self.enabled:
return self._fallback.filter_queryset(request, queryset, view)
try:
return apply_search(queryset, search_query, schema=schema)
except DjangoQLError as exc:
LOGGER.debug("Failed to parse search expression", exc=exc)
return self._fallback.filter_queryset(request, queryset, view)

View File

@ -1,29 +0,0 @@
from djangoql.serializers import DjangoQLSchemaSerializer
from drf_spectacular.generators import SchemaGenerator
from authentik.api.schema import create_component
from authentik.enterprise.search.fields import JSONSearchField
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
def serialize(self, schema):
serialization = super().serialize(schema)
for _, fields in schema.models.items():
for _, field in fields.items():
if not isinstance(field, JSONSearchField):
continue
serialization["models"].update(field.get_nested_options())
return serialization
def serialize_field(self, field):
result = super().serialize_field(field)
if isinstance(field, JSONSearchField):
result["relation"] = field.relation()
return result
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA)
return result

View File

@ -1,17 +0,0 @@
SPECTACULAR_SETTINGS = {
"POSTPROCESSING_HOOKS": [
"authentik.api.schema.postprocess_schema_responses",
"authentik.enterprise.search.schema.postprocess_schema_search_autocomplete",
"drf_spectacular.hooks.postprocess_schema_enums",
],
}
REST_FRAMEWORK = {
"DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination",
"DEFAULT_FILTER_BACKENDS": [
"authentik.enterprise.search.ql.QLSearch",
"authentik.rbac.filters.ObjectFilter",
"django_filters.rest_framework.DjangoFilterBackend",
"rest_framework.filters.OrderingFilter",
],
}

View File

@ -1,78 +0,0 @@
from json import loads
from unittest.mock import PropertyMock, patch
from urllib.parse import urlencode
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
class QLTest(APITestCase):
def setUp(self):
self.user = create_test_admin_user()
# ensure we have more than 1 user
create_test_admin_user()
def test_search(self):
"""Test simple search query"""
self.client.force_login(self.user)
query = f'username = "{self.user.username}"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_no_search(self):
"""Ensure works with no search query"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertNotEqual(content["pagination"]["count"], 1)
def test_search_no_ql(self):
"""Test simple search query (no QL)"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": self.user.username})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_search_json(self):
"""Test search query with a JSON attribute"""
self.user.attributes = {"foo": {"bar": "baz"}}
self.user.save()
self.client.force_login(self.user)
query = 'attributes.foo.bar = "baz"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)

View File

@ -18,7 +18,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.google_workspace", "authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra", "authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.ssf", "authentik.enterprise.providers.ssf",
"authentik.enterprise.search",
"authentik.enterprise.stages.authenticator_endpoint_gdtc", "authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls", "authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source", "authentik.enterprise.stages.source",

View File

@ -97,7 +97,6 @@ class SourceStageFinal(StageView):
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN) token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug) self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan plan = token.plan
plan.context.update(self.executor.plan.context)
plan.context[PLAN_CONTEXT_IS_RESTORED] = token plan.context[PLAN_CONTEXT_IS_RESTORED] = token
response = plan.to_redirect(self.request, token.flow) response = plan.to_redirect(self.request, token.flow)
token.delete() token.delete()

View File

@ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase):
plan: FlowPlan = session[SESSION_KEY_PLAN] plan: FlowPlan = session[SESSION_KEY_PLAN]
plan.insert_stage(in_memory_stage(SourceStageFinal), index=0) plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
plan.context["foo"] = "bar"
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session.save() session.save()
# Pretend we've just returned from the source # Pretend we've just returned from the source
with self.assertFlowFinishes() as ff: response = self.client.get(
response = self.client.get( reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True )
) self.assertEqual(response.status_code, 200)
self.assertEqual(response.status_code, 200) self.assertStageRedirects(
self.assertStageRedirects( response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug}) )
)
self.assertEqual(ff().context["foo"], "bar")

View File

@ -132,22 +132,6 @@ class EventViewSet(ModelViewSet):
] ]
filterset_class = EventsFilter filterset_class = EventsFilter
def get_ql_fields(self):
from djangoql.schema import DateTimeField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
ChoiceSearchField(Event, "action"),
StrField(Event, "event_uuid"),
StrField(Event, "app", suggest_options=True),
StrField(Event, "client_ip"),
JSONSearchField(Event, "user", suggest_nested=False),
JSONSearchField(Event, "brand", suggest_nested=False),
JSONSearchField(Event, "context", suggest_nested=False),
DateTimeField(Event, "created", suggest_options=True),
]
@extend_schema( @extend_schema(
methods=["GET"], methods=["GET"],
responses={200: EventTopPerUserSerializer(many=True)}, responses={200: EventTopPerUserSerializer(many=True)},

View File

@ -11,7 +11,7 @@ from authentik.events.models import NotificationRule
class NotificationRuleSerializer(ModelSerializer): class NotificationRuleSerializer(ModelSerializer):
"""NotificationRule Serializer""" """NotificationRule Serializer"""
destination_group_obj = GroupSerializer(read_only=True, source="destination_group") group_obj = GroupSerializer(read_only=True, source="group")
class Meta: class Meta:
model = NotificationRule model = NotificationRule
@ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer):
"name", "name",
"transports", "transports",
"severity", "severity",
"destination_group", "group",
"destination_group_obj", "group_obj",
"destination_event_user",
] ]
@ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet):
queryset = NotificationRule.objects.all() queryset = NotificationRule.objects.all()
serializer_class = NotificationRuleSerializer serializer_class = NotificationRuleSerializer
filterset_fields = ["name", "severity", "destination_group__name"] filterset_fields = ["name", "severity", "group__name"]
ordering = ["name"] ordering = ["name"]
search_fields = ["name", "destination_group__name"] search_fields = ["name", "group__name"]

View File

@ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor):
self.reader: Reader | None = None self.reader: Reader | None = None
self._last_mtime: float = 0.0 self._last_mtime: float = 0.0
self.logger = get_logger() self.logger = get_logger()
self.load() self.open()
def path(self) -> str | None: def path(self) -> str | None:
"""Get the path to the MMDB file to load""" """Get the path to the MMDB file to load"""
raise NotImplementedError raise NotImplementedError
def load(self): def open(self):
"""Get GeoIP Reader, if configured, otherwise none""" """Get GeoIP Reader, if configured, otherwise none"""
path = self.path() path = self.path()
if path == "" or not path: if path == "" or not path:
@ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor):
diff = self._last_mtime < mtime diff = self._last_mtime < mtime
if diff > 0: if diff > 0:
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path) self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
self.load() self.open()
except OSError as exc: except OSError as exc:
self.logger.warning("Failed to check MMDB age", exc=exc) self.logger.warning("Failed to check MMDB age", exc=exc)

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction, Notification from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict from authentik.events.utils import model_to_dict
from authentik.lib.sentry import should_ignore_exception from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.stages.authenticator_static.models import StaticToken from authentik.stages.authenticator_static.models import StaticToken
@ -173,7 +173,7 @@ class AuditMiddleware:
message=exception_to_string(exception), message=exception_to_string(exception),
) )
thread.run() thread.run()
elif not should_ignore_exception(exception): elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
thread = EventNewThread( thread = EventNewThread(
EventAction.SYSTEM_EXCEPTION, EventAction.SYSTEM_EXCEPTION,
request, request,

View File

@ -1,26 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-16 23:21
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
]
operations = [
migrations.RenameField(
model_name="notificationrule",
old_name="group",
new_name="destination_group",
),
migrations.AddField(
model_name="notificationrule",
name="destination_event_user",
field=models.BooleanField(
default=False,
help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.",
),
),
]

View File

@ -1,12 +1,10 @@
"""authentik events models""" """authentik events models"""
from collections.abc import Generator
from datetime import timedelta from datetime import timedelta
from difflib import get_close_matches from difflib import get_close_matches
from functools import lru_cache from functools import lru_cache
from inspect import currentframe from inspect import currentframe
from smtplib import SMTPException from smtplib import SMTPException
from typing import Any
from uuid import uuid4 from uuid import uuid4
from django.apps import apps from django.apps import apps
@ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel):
brand: Brand = request.brand brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand)) self.brand = sanitize_dict(model_to_dict(brand))
if hasattr(request, "user"): if hasattr(request, "user"):
self.user = get_user(request.user) original_user = None
if hasattr(request, "session"):
original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None)
self.user = get_user(request.user, original_user)
if user: if user:
self.user = get_user(user) self.user = get_user(user)
# Check if we're currently impersonating, and add that user
if hasattr(request, "session"): if hasattr(request, "session"):
from authentik.flows.views.executor import SESSION_KEY_PLAN
# Check if we're currently impersonating, and add that user
if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session: if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER]) self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER]) self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
# Special case for events that happen during a flow, the user might not be authenticated
# yet but is a pending user instead
if SESSION_KEY_PLAN in request.session:
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None)
# Only save `authenticated_as` if there's a different pending user in the flow
# than the user that is authenticated
if pending_user and (
(pending_user.pk and pending_user.pk != self.user.get("pk"))
or (not pending_user.pk)
):
orig_user = self.user.copy()
self.user = {"authenticated_as": orig_user, **get_user(pending_user)}
# User 255.255.255.255 as fallback if IP cannot be determined # User 255.255.255.255 as fallback if IP cannot be determined
self.client_ip = ClientIPMiddleware.get_client_ip(request) self.client_ip = ClientIPMiddleware.get_client_ip(request)
# Enrich event data # Enrich event data
@ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
default=NotificationSeverity.NOTICE, default=NotificationSeverity.NOTICE,
help_text=_("Controls which severity level the created notifications will have."), help_text=_("Controls which severity level the created notifications will have."),
) )
destination_group = models.ForeignKey( group = models.ForeignKey(
Group, Group,
help_text=_( help_text=_(
"Define which group of users this notification should be sent and shown to. " "Define which group of users this notification should be sent and shown to. "
@ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
blank=True, blank=True,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
) )
destination_event_user = models.BooleanField(
default=False,
help_text=_(
"When enabled, notification will be sent to user the user that triggered the event."
"When destination_group is configured, notification is sent to both."
),
)
def destination_users(self, event: Event) -> Generator[User, Any]:
if self.destination_event_user and event.user.get("pk"):
yield User(pk=event.user.get("pk"))
if self.destination_group:
yield from self.destination_group.users.all()
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> type[Serializer]:

View File

@ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if not result.passing: if not result.passing:
return return
if not trigger.group:
LOGGER.debug("e(trigger): trigger has no group", trigger=trigger)
return
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger) LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
# Create the notification objects # Create the notification objects
for transport in trigger.transports.all(): for transport in trigger.transports.all():
for user in trigger.destination_users(event): for user in trigger.group.users.all():
LOGGER.debug("created notification") LOGGER.debug("created notification")
notification_transport.apply_async( notification_transport.apply_async(
args=[ args=[

View File

@ -2,9 +2,7 @@
from django.test import TestCase from django.test import TestCase
from authentik.events.context_processors.base import get_context_processors
from authentik.events.context_processors.geoip import GeoIPContextProcessor from authentik.events.context_processors.geoip import GeoIPContextProcessor
from authentik.events.models import Event, EventAction
class TestGeoIP(TestCase): class TestGeoIP(TestCase):
@ -15,7 +13,8 @@ class TestGeoIP(TestCase):
def test_simple(self): def test_simple(self):
"""Test simple city wrapper""" """Test simple city wrapper"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json # IPs from
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
self.assertEqual( self.assertEqual(
self.reader.city_dict("2.125.160.216"), self.reader.city_dict("2.125.160.216"),
{ {
@ -26,12 +25,3 @@ class TestGeoIP(TestCase):
"long": -1.25, "long": -1.25,
}, },
) )
def test_special_chars(self):
"""Test city name with special characters"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
event = Event.new(EventAction.LOGIN)
event.client_ip = "89.160.20.112"
for processor in get_context_processors():
processor.enrich_event(event)
event.save()

View File

@ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter
from guardian.shortcuts import get_anonymous_user from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import Group, User from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan from authentik.flows.views.executor import QS_QUERY
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy from authentik.policies.dummy.models import DummyPolicy
@ -118,92 +116,3 @@ class TestEvents(TestCase):
"pk": brand.pk.hex, "pk": brand.pk.hex,
}, },
) )
def test_from_http_flow_pending_user(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = user
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_anon(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_fake(self):
"""Test request from flow request with a pending user"""
user = User(
username=generate_id(),
email=generate_id(),
)
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)

View File

@ -6,7 +6,6 @@ from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import ( from authentik.events.models import (
Event, Event,
EventAction, EventAction,
@ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_empty(self): def test_trigger_empty(self):
"""Test trigger without any policies attached""" """Test trigger without any policies attached"""
transport = NotificationTransport.objects.create(name=generate_id()) transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport) trigger.transports.add(transport)
trigger.save() trigger.save()
@ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_single(self): def test_trigger_single(self):
"""Test simple transport triggering""" """Test simple transport triggering"""
transport = NotificationTransport.objects.create(name=generate_id()) transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport) trigger.transports.add(transport)
trigger.save() trigger.save()
matcher = EventMatcherPolicy.objects.create( matcher = EventMatcherPolicy.objects.create(
@ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase):
Event.new(EventAction.CUSTOM_PREFIX).save() Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(execute_mock.call_count, 1) self.assertEqual(execute_mock.call_count, 1)
def test_trigger_event_user(self):
"""Test trigger with event user"""
user = create_test_user()
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX
)
PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
execute_mock = MagicMock()
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save()
self.assertEqual(execute_mock.call_count, 1)
notification: Notification = execute_mock.call_args[0][0]
self.assertEqual(notification.user, user)
def test_trigger_no_group(self): def test_trigger_no_group(self):
"""Test trigger without group""" """Test trigger without group"""
trigger = NotificationRule.objects.create(name=generate_id()) trigger = NotificationRule.objects.create(name=generate_id())
@ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase):
"""Test Policy error which would cause recursion""" """Test Policy error which would cause recursion"""
transport = NotificationTransport.objects.create(name=generate_id()) transport = NotificationTransport.objects.create(name=generate_id())
NotificationRule.objects.filter(name__startswith="default").delete() NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport) trigger.transports.add(transport)
trigger.save() trigger.save()
matcher = EventMatcherPolicy.objects.create( matcher = EventMatcherPolicy.objects.create(
@ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase):
transport = NotificationTransport.objects.create(name=generate_id(), send_once=True) transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete() NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport) trigger.transports.add(transport)
trigger.save() trigger.save()
matcher = EventMatcherPolicy.objects.create( matcher = EventMatcherPolicy.objects.create(
@ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase):
name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
) )
NotificationRule.objects.filter(name__startswith="default").delete() NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group) trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport) trigger.transports.add(transport)
matcher = EventMatcherPolicy.objects.create( matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX name="matcher", action=EventAction.CUSTOM_PREFIX

View File

@ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]:
} }
def get_user(user: User | AnonymousUser) -> dict[str, Any]: def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]:
"""Convert user object to dictionary""" """Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser): if isinstance(user, AnonymousUser):
try: try:
user = get_anonymous_user() user = get_anonymous_user()
@ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]:
} }
if user.username == settings.ANONYMOUS_USER_NAME: if user.username == settings.ANONYMOUS_USER_NAME:
user_data["is_anonymous"] = True user_data["is_anonymous"] = True
if original_user:
original_data = get_user(original_user)
original_data["on_behalf_of"] = user_data
return original_data
return user_data return user_data

View File

@ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.test import override_settings
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.urls import reverse from django.urls import reverse
from rest_framework.exceptions import ParseError
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_flow, create_test_user from authentik.core.tests.utils import create_test_flow, create_test_user
@ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase):
self.assertStageResponse(response, flow, component="ak-stage-identification") self.assertStageResponse(response, flow, component="ak-stage-identification")
response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True) response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-access-denied") self.assertStageResponse(response, flow, component="ak-stage-access-denied")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_json(self):
"""Test invalid JSON body"""
flow = create_test_flow()
FlowStageBinding.objects.create(
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
)
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
with override_settings(TEST=False, DEBUG=False):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)
with self.assertRaises(ParseError):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)

View File

@ -55,7 +55,7 @@ from authentik.flows.planner import (
FlowPlanner, FlowPlanner,
) )
from authentik.flows.stage import AccessDeniedStage, StageView from authentik.flows.stage import AccessDeniedStage, StageView
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
@ -234,13 +234,12 @@ class FlowExecutorView(APIView):
"""Handle exception in stage execution""" """Handle exception in stage execution"""
if settings.DEBUG or settings.TEST: if settings.DEBUG or settings.TEST:
raise exc raise exc
capture_exception(exc)
self._logger.warning(exc) self._logger.warning(exc)
if not should_ignore_exception(exc): Event.new(
capture_exception(exc) action=EventAction.SYSTEM_EXCEPTION,
Event.new( message=exception_to_string(exc),
action=EventAction.SYSTEM_EXCEPTION, ).from_http(self.request)
message=exception_to_string(exc),
).from_http(self.request)
challenge = FlowErrorChallenge(self.request, exc) challenge = FlowErrorChallenge(self.request, exc)
challenge.is_valid(raise_exception=True) challenge.is_valid(raise_exception=True)
return to_stage_response(self.request, HttpChallengeResponse(challenge)) return to_stage_response(self.request, HttpChallengeResponse(challenge))

View File

@ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted
from docker.errors import DockerException from docker.errors import DockerException
from h11 import LocalProtocolError from h11 import LocalProtocolError
from ldap3.core.exceptions import LDAPException from ldap3.core.exceptions import LDAPException
from psycopg.errors import Error
from redis.exceptions import ConnectionError as RedisConnectionError from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
@ -45,49 +44,6 @@ class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry.""" """Base Class for all errors that are suppressed, and not sent to sentry."""
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
class SentryTransport(HttpTransport): class SentryTransport(HttpTransport):
"""Custom sentry transport with custom user-agent""" """Custom sentry transport with custom user-agent"""
@ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float:
return float(CONFIG.get("error_reporting.sample_rate", 0.1)) return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def should_ignore_exception(exc: Exception) -> bool:
"""Check if an exception should be dropped"""
return isinstance(exc, ignored_classes)
def before_send(event: dict, hint: dict) -> dict | None: def before_send(event: dict, hint: dict) -> dict | None:
"""Check if error is database error, and ignore if so""" """Check if error is database error, and ignore if so"""
from psycopg.errors import Error
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
exc_value = None exc_value = None
if "exc_info" in hint: if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"] _, exc_value, _ = hint["exc_info"]
if should_ignore_exception(exc_value): if isinstance(exc_value, ignored_classes):
LOGGER.debug("dropping exception", exc=exc_value) LOGGER.debug("dropping exception", exc=exc_value)
return None return None
if "logger" in event: if "logger" in event:

View File

@ -2,7 +2,7 @@
from django.test import TestCase from django.test import TestCase
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception from authentik.lib.sentry import SentryIgnoredException, before_send
class TestSentry(TestCase): class TestSentry(TestCase):
@ -10,8 +10,8 @@ class TestSentry(TestCase):
def test_error_not_sent(self): def test_error_not_sent(self):
"""Test SentryIgnoredError not sent""" """Test SentryIgnoredError not sent"""
self.assertTrue(should_ignore_exception(SentryIgnoredException())) self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
def test_error_sent(self): def test_error_sent(self):
"""Test error sent""" """Test error sent"""
self.assertFalse(should_ignore_exception(ValueError())) self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)}))

View File

@ -1,13 +1,15 @@
"""authentik outpost signals""" """authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache from django.core.cache import cache
from django.db.models import Model from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.brands.models import Brand from authentik.brands.models import Brand
from authentik.core.models import AuthenticatedSession, Provider from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection from authentik.outposts.models import Outpost, OutpostServiceConnection
@ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True) outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession) @receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_): def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted""" """Catch logout by expiring sessions being deleted"""

View File

@ -15,6 +15,7 @@ class OAuth2Error(SentryIgnoredException):
error: str error: str
description: str description: str
cause: str | None = None
def create_dict(self): def create_dict(self):
"""Return error as dict for JSON Rendering""" """Return error as dict for JSON Rendering"""
@ -34,6 +35,10 @@ class OAuth2Error(SentryIgnoredException):
**kwargs, **kwargs,
) )
def with_cause(self, cause: str):
self.cause = cause
return self
class RedirectUriError(OAuth2Error): class RedirectUriError(OAuth2Error):
"""The request fails due to a missing, invalid, or mismatching """The request fails due to a missing, invalid, or mismatching

View File

@ -1,10 +1,23 @@
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import post_save, pre_delete from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User from authentik.core.models import AuthenticatedSession, User
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
@receiver(user_logged_out)
def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
"""Revoke tokens upon user logout"""
if not request.session or not request.session.session_key:
return
AccessToken.objects.filter(
user=user,
session__session__session_key=request.session.session_key,
).delete()
@receiver(pre_delete, sender=AuthenticatedSession) @receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_): def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
"""Revoke tokens upon user logout""" """Revoke tokens upon user logout"""

View File

@ -12,7 +12,7 @@ from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string from authentik.lib.utils.time import timedelta_from_string
from authentik.providers.oauth2.constants import TOKEN_TYPE from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE
from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError from authentik.providers.oauth2.errors import AuthorizeError, ClientIdError, RedirectUriError
from authentik.providers.oauth2.models import ( from authentik.providers.oauth2.models import (
AccessToken, AccessToken,
@ -43,7 +43,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
) )
with self.assertRaises(AuthorizeError): with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -53,6 +53,7 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_invalid_client_id(self): def test_invalid_client_id(self):
"""Test invalid client ID""" """Test invalid client ID"""
@ -68,7 +69,7 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid/Foo")],
) )
with self.assertRaises(AuthorizeError): with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -79,19 +80,30 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "request_not_supported")
def test_invalid_redirect_uri(self): def test_invalid_redirect_uri_missing(self):
"""test missing/invalid redirect URI""" """test missing redirect URI"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name=generate_id(), name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
) )
with self.assertRaises(RedirectUriError): with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"}) request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError): self.assertEqual(cm.exception.cause, "redirect_uri_missing")
def test_invalid_redirect_uri(self):
"""test invalid redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid")],
)
with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -101,6 +113,7 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_blocked_redirect_uri(self): def test_blocked_redirect_uri(self):
"""test missing/invalid redirect URI""" """test missing/invalid redirect URI"""
@ -108,9 +121,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(), name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:local.invalid")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "data:localhost")],
) )
with self.assertRaises(RedirectUriError): with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -120,6 +133,7 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_forbidden_scheme")
def test_invalid_redirect_uri_empty(self): def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI""" """test missing/invalid redirect URI"""
@ -129,9 +143,6 @@ class TestAuthorize(OAuthTestCase):
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[], redirect_uris=[],
) )
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -150,12 +161,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(), name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://local.invalid?")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "http://local.invalid?")],
) )
with self.assertRaises(RedirectUriError): with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -165,6 +173,7 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_redirect_uri_invalid_regex(self): def test_redirect_uri_invalid_regex(self):
"""test missing/invalid redirect URI (invalid regex)""" """test missing/invalid redirect URI (invalid regex)"""
@ -172,12 +181,9 @@ class TestAuthorize(OAuthTestCase):
name=generate_id(), name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "+")], redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, "+")],
) )
with self.assertRaises(RedirectUriError): with self.assertRaises(RedirectUriError) as cm:
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
with self.assertRaises(RedirectUriError):
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -187,23 +193,22 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "redirect_uri_no_match")
def test_empty_redirect_uri(self): def test_redirect_uri_regex(self):
"""test empty redirect URI (configure in provider)""" """test valid redirect URI (regex)"""
OAuth2Provider.objects.create( OAuth2Provider.objects.create(
name=generate_id(), name=generate_id(),
client_id="test", client_id="test",
authorization_flow=create_test_flow(), authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.REGEX, ".+")],
) )
with self.assertRaises(RedirectUriError):
request = self.factory.get("/", data={"response_type": "code", "client_id": "test"})
OAuthAuthorizationParams.from_request(request)
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
"response_type": "code", "response_type": "code",
"client_id": "test", "client_id": "test",
"redirect_uri": "http://localhost", "redirect_uri": "http://foo.bar.baz",
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
@ -258,7 +263,7 @@ class TestAuthorize(OAuthTestCase):
GrantTypes.IMPLICIT, GrantTypes.IMPLICIT,
) )
# Implicit without openid scope # Implicit without openid scope
with self.assertRaises(AuthorizeError): with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -285,7 +290,7 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual( self.assertEqual(
OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID OAuthAuthorizationParams.from_request(request).grant_type, GrantTypes.HYBRID
) )
with self.assertRaises(AuthorizeError): with self.assertRaises(AuthorizeError) as cm:
request = self.factory.get( request = self.factory.get(
"/", "/",
data={ data={
@ -295,6 +300,7 @@ class TestAuthorize(OAuthTestCase):
}, },
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.error, "unsupported_response_type")
def test_full_code(self): def test_full_code(self):
"""Test full authorization""" """Test full authorization"""
@ -387,7 +393,8 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual( self.assertEqual(
response.url, response.url,
( (
f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}" f"http://localhost#access_token={token.token}"
f"&id_token={provider.encode(token.id_token.to_dict())}"
f"&token_type={TOKEN_TYPE}" f"&token_type={TOKEN_TYPE}"
f"&expires_in={int(expires)}&state={state}" f"&expires_in={int(expires)}&state={state}"
), ),
@ -562,6 +569,7 @@ class TestAuthorize(OAuthTestCase):
"url": "http://localhost", "url": "http://localhost",
"title": f"Redirecting to {app.name}...", "title": f"Redirecting to {app.name}...",
"attrs": { "attrs": {
"access_token": token.token,
"id_token": provider.encode(token.id_token.to_dict()), "id_token": provider.encode(token.id_token.to_dict()),
"token_type": TOKEN_TYPE, "token_type": TOKEN_TYPE,
"expires_in": "3600", "expires_in": "3600",
@ -613,3 +621,54 @@ class TestAuthorize(OAuthTestCase):
}, },
}, },
) )
def test_openid_missing_invalid(self):
"""test request requiring an OpenID scope to be set"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": "",
},
)
with self.assertRaises(AuthorizeError) as cm:
OAuthAuthorizationParams.from_request(request)
self.assertEqual(cm.exception.cause, "scope_openid_missing")
@apply_blueprint("system/providers-oauth2.yaml")
def test_offline_access_invalid(self):
"""test request for offline_access with invalid response type"""
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "http://localhost")],
)
provider.property_mappings.set(
ScopeMapping.objects.filter(
managed__in=[
"goauthentik.io/providers/oauth2/scope-openid",
"goauthentik.io/providers/oauth2/scope-offline_access",
]
)
)
request = self.factory.get(
"/",
data={
"response_type": "id_token",
"client_id": "test",
"redirect_uri": "http://localhost",
"scope": f"{SCOPE_OPENID} {SCOPE_OFFLINE_ACCESS}",
"nonce": generate_id(),
},
)
parsed = OAuthAuthorizationParams.from_request(request)
self.assertNotIn(SCOPE_OFFLINE_ACCESS, parsed.scope)

View File

@ -150,12 +150,12 @@ class OAuthAuthorizationParams:
self.check_redirect_uri() self.check_redirect_uri()
self.check_grant() self.check_grant()
self.check_scope(github_compat) self.check_scope(github_compat)
self.check_nonce()
self.check_code_challenge()
if self.request: if self.request:
raise AuthorizeError( raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state self.redirect_uri, "request_not_supported", self.grant_type, self.state
) )
self.check_nonce()
self.check_code_challenge()
def check_grant(self): def check_grant(self):
"""Check grant""" """Check grant"""
@ -190,7 +190,7 @@ class OAuthAuthorizationParams:
allowed_redirect_urls = self.provider.redirect_uris allowed_redirect_urls = self.provider.redirect_uris
if not self.redirect_uri: if not self.redirect_uri:
LOGGER.warning("Missing redirect uri.") LOGGER.warning("Missing redirect uri.")
raise RedirectUriError("", allowed_redirect_urls) raise RedirectUriError("", allowed_redirect_urls).with_cause("redirect_uri_missing")
if len(allowed_redirect_urls) < 1: if len(allowed_redirect_urls) < 1:
LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri) LOGGER.info("Setting redirect for blank redirect_uris", redirect=self.redirect_uri)
@ -219,10 +219,14 @@ class OAuthAuthorizationParams:
provider=self.provider, provider=self.provider,
) )
if not match_found: if not match_found:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_no_match"
)
# Check against forbidden schemes # Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) raise RedirectUriError(self.redirect_uri, allowed_redirect_urls).with_cause(
"redirect_uri_forbidden_scheme"
)
def check_scope(self, github_compat=False): def check_scope(self, github_compat=False):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" """Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
@ -251,7 +255,9 @@ class OAuthAuthorizationParams:
or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN] or self.response_type in [ResponseTypes.ID_TOKEN, ResponseTypes.ID_TOKEN_TOKEN]
): ):
LOGGER.warning("Missing 'openid' scope.") LOGGER.warning("Missing 'openid' scope.")
raise AuthorizeError(self.redirect_uri, "invalid_scope", self.grant_type, self.state) raise AuthorizeError(
self.redirect_uri, "invalid_scope", self.grant_type, self.state
).with_cause("scope_openid_missing")
if SCOPE_OFFLINE_ACCESS in self.scope: if SCOPE_OFFLINE_ACCESS in self.scope:
# https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess # https://openid.net/specs/openid-connect-core-1_0.html#OfflineAccess
# Don't explicitly request consent with offline_access, as the spec allows for # Don't explicitly request consent with offline_access, as the spec allows for
@ -286,7 +292,9 @@ class OAuthAuthorizationParams:
return return
if not self.nonce: if not self.nonce:
LOGGER.warning("Missing nonce for OpenID Request") LOGGER.warning("Missing nonce for OpenID Request")
raise AuthorizeError(self.redirect_uri, "invalid_request", self.grant_type, self.state) raise AuthorizeError(
self.redirect_uri, "invalid_request", self.grant_type, self.state
).with_cause("none_missing")
def check_code_challenge(self): def check_code_challenge(self):
"""PKCE validation of the transformation method.""" """PKCE validation of the transformation method."""
@ -345,10 +353,10 @@ class AuthorizationFlowInitView(PolicyAccessView):
self.request, github_compat=self.github_compat self.request, github_compat=self.github_compat
) )
except AuthorizeError as error: except AuthorizeError as error:
LOGGER.warning(error.description, redirect_uri=error.redirect_uri) LOGGER.warning(error.description, redirect_uri=error.redirect_uri, cause=error.cause)
raise RequestValidationError(error.get_response(self.request)) from None raise RequestValidationError(error.get_response(self.request)) from None
except OAuth2Error as error: except OAuth2Error as error:
LOGGER.warning(error.description) LOGGER.warning(error.description, cause=error.cause)
raise RequestValidationError( raise RequestValidationError(
bad_request_message(self.request, error.description, title=error.error) bad_request_message(self.request, error.description, title=error.error)
) from None ) from None
@ -630,6 +638,7 @@ class OAuthFulfillmentStage(StageView):
if self.params.response_type in [ if self.params.response_type in [
ResponseTypes.ID_TOKEN_TOKEN, ResponseTypes.ID_TOKEN_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN, ResponseTypes.CODE_ID_TOKEN_TOKEN,
ResponseTypes.ID_TOKEN,
ResponseTypes.CODE_TOKEN, ResponseTypes.CODE_TOKEN,
]: ]:
query_fragment["access_token"] = token.token query_fragment["access_token"] = token.token

View File

@ -40,16 +40,9 @@ class ConnectionTokenViewSet(
): ):
"""ConnectionToken Viewset""" """ConnectionToken Viewset"""
queryset = ConnectionToken.objects.none() queryset = ConnectionToken.objects.all().select_related("session", "endpoint")
serializer_class = ConnectionTokenSerializer serializer_class = ConnectionTokenSerializer
filterset_fields = ["endpoint", "session__user"] filterset_fields = ["endpoint", "session__user", "provider"]
search_fields = ["endpoint__name", "session__user__username"] search_fields = ["endpoint__name", "provider__name"]
ordering = ["endpoint__name", "session__user__username"] ordering = ["endpoint__name", "provider__name"]
owner_field = "session__user" owner_field = "session__user"
def get_queryset(self):
return (
ConnectionToken.objects.all()
.select_related("session", "endpoint")
.filter(provider=self.kwargs["provider_pk"])
)

View File

@ -22,9 +22,9 @@ from authentik.rbac.filters import ObjectFilter
LOGGER = get_logger() LOGGER = get_logger()
def user_endpoint_cache_key(user_pk: str, provider_pk: str) -> str: def user_endpoint_cache_key(user_pk: str) -> str:
"""Cache key where endpoint list for user is saved""" """Cache key where endpoint list for user is saved"""
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}/{provider_pk}" return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
class EndpointSerializer(ModelSerializer): class EndpointSerializer(ModelSerializer):
@ -65,15 +65,12 @@ class EndpointSerializer(ModelSerializer):
class EndpointViewSet(UsedByMixin, ModelViewSet): class EndpointViewSet(UsedByMixin, ModelViewSet):
"""Endpoint Viewset""" """Endpoint Viewset"""
queryset = Endpoint.objects.none() queryset = Endpoint.objects.all()
serializer_class = EndpointSerializer serializer_class = EndpointSerializer
filterset_fields = ["name"] filterset_fields = ["name", "provider"]
search_fields = ["name", "protocol"] search_fields = ["name", "protocol"]
ordering = ["name", "protocol"] ordering = ["name", "protocol"]
def get_queryset(self):
return Endpoint.objects.filter(provider=self.kwargs["provider_pk"])
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet: def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
"""Custom filter_queryset method which ignores guardian, but still supports sorting""" """Custom filter_queryset method which ignores guardian, but still supports sorting"""
for backend in list(self.filter_backends): for backend in list(self.filter_backends):
@ -123,11 +120,14 @@ class EndpointViewSet(UsedByMixin, ModelViewSet):
if not should_cache: if not should_cache:
allowed_endpoints = self._get_allowed_endpoints(queryset) allowed_endpoints = self._get_allowed_endpoints(queryset)
if should_cache: if should_cache:
key = user_endpoint_cache_key(self.request.user.pk, self.kwargs["provider_pk"]) allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
allowed_endpoints = cache.get(key)
if not allowed_endpoints: if not allowed_endpoints:
LOGGER.debug("Caching allowed endpoint list") LOGGER.debug("Caching allowed endpoint list")
allowed_endpoints = self._get_allowed_endpoints(queryset) allowed_endpoints = self._get_allowed_endpoints(queryset)
cache.set(key, allowed_endpoints, timeout=86400) cache.set(
user_endpoint_cache_key(self.request.user.pk),
allowed_endpoints,
timeout=86400,
)
serializer = self.get_serializer(allowed_endpoints, many=True) serializer = self.get_serializer(allowed_endpoints, many=True)
return self.get_paginated_response(serializer.data) return self.get_paginated_response(serializer.data)

View File

@ -66,10 +66,7 @@ class RACClientConsumer(AsyncWebsocketConsumer):
def init_outpost_connection(self): def init_outpost_connection(self):
"""Initialize guac connection settings""" """Initialize guac connection settings"""
self.token = ( self.token = (
ConnectionToken.filter_not_expired( ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"])
token=self.scope["url_route"]["kwargs"]["token"],
session__session__session_key=self.scope["session"].session_key,
)
.select_related("endpoint", "provider", "session", "session__user") .select_related("endpoint", "provider", "session", "session__user")
.first() .first()
) )

View File

@ -2,11 +2,13 @@
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_delete from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession from authentik.core.models import AuthenticatedSession, User
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import ( from authentik.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION, RAC_CLIENT_GROUP_SESSION,
@ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import (
from authentik.providers.rac.models import ConnectionToken, Endpoint from authentik.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
if not request.session or not request.session.session_key:
return
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=AuthenticatedSession) @receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted(sender, instance: AuthenticatedSession, **_): def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer() layer = get_channel_layer()
@ -43,5 +60,5 @@ def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **
@receiver([post_save, post_delete], sender=Endpoint) @receiver([post_save, post_delete], sender=Endpoint)
def post_save_post_delete_endpoint(**_): def post_save_post_delete_endpoint(**_):
"""Clear user's endpoint cache upon endpoint creation or deletion""" """Clear user's endpoint cache upon endpoint creation or deletion"""
keys = cache.keys(user_endpoint_cache_key("*", "*")) keys = cache.keys(user_endpoint_cache_key("*"))
cache.delete_many(keys) cache.delete_many(keys)

View File

@ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,
@ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
response.content.decode(), response.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,

View File

@ -87,22 +87,3 @@ class TestRACViews(APITestCase):
) )
body = loads(flow_response.content) body = loads(flow_response.content)
self.assertEqual(body["component"], "ak-stage-access-denied") self.assertEqual(body["component"], "ak-stage-access-denied")
def test_different_session(self):
"""Test request"""
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
next_url = body["to"]
self.client.logout()
final_response = self.client.get(next_url)
self.assertEqual(final_response.url, reverse("authentik_core:if-user"))

View File

@ -2,7 +2,6 @@
from django.urls import path from django.urls import path
from authentik.api.v3.routers import NestedRouter
from authentik.outposts.channels import TokenOutpostMiddleware from authentik.outposts.channels import TokenOutpostMiddleware
from authentik.providers.rac.api.connection_tokens import ConnectionTokenViewSet from authentik.providers.rac.api.connection_tokens import ConnectionTokenViewSet
from authentik.providers.rac.api.endpoints import EndpointViewSet from authentik.providers.rac.api.endpoints import EndpointViewSet
@ -39,10 +38,8 @@ websocket_urlpatterns = [
] ]
api_urlpatterns = [ api_urlpatterns = [
*NestedRouter() ("providers/rac", RACProviderViewSet),
.register("providers/rac", RACProviderViewSet)
.nested("provider", "endpoints", EndpointViewSet)
.nested("provider", "connection_tokens", ConnectionTokenViewSet)
.urls,
("propertymappings/provider/rac", RACPropertyMappingViewSet), ("propertymappings/provider/rac", RACPropertyMappingViewSet),
("rac/endpoints", EndpointViewSet),
("rac/connection_tokens", ConnectionTokenViewSet),
] ]

View File

@ -68,10 +68,7 @@ class RACInterface(InterfaceView):
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# Early sanity check to ensure token still exists # Early sanity check to ensure token still exists
token = ConnectionToken.filter_not_expired( token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
token=self.kwargs["token"],
session__session__session_key=request.session.session_key,
).first()
if not token: if not token:
return redirect("authentik_core:if-user") return redirect("authentik_core:if-user")
self.token = token self.token = token

View File

@ -5,6 +5,7 @@ from itertools import batched
from django.db import transaction from django.db import transaction
from pydantic import ValidationError from pydantic import ValidationError
from pydanticscim.group import GroupMember from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from authentik.core.models import Group from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager from authentik.lib.sync.mapper import PropertyMappingManager
@ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import ( from authentik.providers.scim.clients.exceptions import (
SCIMRequestException, SCIMRequestException,
) )
from authentik.providers.scim.clients.schema import ( from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
SCIM_GROUP_SCHEMA,
PatchOp,
PatchOperation,
PatchRequest,
)
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import ( from authentik.providers.scim.models import (
SCIMMapping, SCIMMapping,

View File

@ -1,7 +1,5 @@
"""Custom SCIM schemas""" """Custom SCIM schemas"""
from enum import Enum
from pydantic import Field from pydantic import Field
from pydanticscim.group import Group as BaseGroup from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation from pydanticscim.responses import PatchOperation as BasePatchOperation
@ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
) )
class PatchOp(str, Enum):
replace = "replace"
remove = "remove"
add = "add"
@classmethod
def _missing_(cls, value):
value = value.lower()
for member in cls:
if member.lower() == value:
return member
return None
class PatchRequest(BasePatchRequest): class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas""" """PatchRequest which correctly sets schemas"""
@ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest):
class PatchOperation(BasePatchOperation): class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path""" """PatchOperation with optional path"""
op: PatchOp
path: str | None path: str | None

View File

@ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
res.content.decode(), res.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,

View File

@ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
res.content.decode(), res.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,

View File

@ -38,7 +38,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
res.content.decode(), res.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,
@ -74,7 +73,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual( self.assertJSONEqual(
res.content.decode(), res.content.decode(),
{ {
"autocomplete": {},
"pagination": { "pagination": {
"next": 0, "next": 0,
"previous": 0, "previous": 0,

View File

@ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
import django import django
from channels.routing import ProtocolTypeRouter, URLRouter from channels.routing import ProtocolTypeRouter, URLRouter
from defusedxml import defuse_stdlib
from django.core.asgi import get_asgi_application from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from authentik.root.setup import setup
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py # DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
setup() defuse_stdlib()
django.setup() django.setup()

View File

@ -27,7 +27,7 @@ from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik import get_full_version from authentik import get_full_version
from authentik.lib.sentry import should_ignore_exception from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
# set the default Django settings module for the 'celery' program. # set the default Django settings module for the 'celery' program.
@ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception) LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
CTX_TASK_ID.set(...) CTX_TASK_ID.set(...)
if not should_ignore_exception(exception): if before_send({}, {"exc_info": (None, exception, None)}) is not None:
Event.new( Event.new(
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
).save() ).save()

View File

@ -1,49 +1,13 @@
"""authentik database backend""" """authentik database backend"""
from django.core.checks import Warning
from django.db.backends.base.validation import BaseDatabaseValidation
from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
return self._check_encoding()
def _check_encoding(self):
"""Throw a warning when the server_encoding is not UTF-8 or
server_encoding and client_encoding are mismatched"""
messages = []
with self.connection.cursor() as cursor:
cursor.execute("SHOW server_encoding;")
server_encoding = cursor.fetchone()[0]
cursor.execute("SHOW client_encoding;")
client_encoding = cursor.fetchone()[0]
if server_encoding != client_encoding:
messages.append(
Warning(
"PostgreSQL Server and Client encoding are mismatched: Server: "
f"{server_encoding}, Client: {client_encoding}",
id="ak.db.W001",
)
)
if server_encoding != "UTF8":
messages.append(
Warning(
f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
id="ak.db.W002",
)
)
return messages
class DatabaseWrapper(BaseDatabaseWrapper): class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials""" """database backend which supports rotating credentials"""
validation_class = DatabaseValidation
def get_connection_params(self): def get_connection_params(self):
"""Refresh DB credentials before getting connection params""" """Refresh DB credentials before getting connection params"""
conn_params = super().get_connection_params() conn_params = super().get_connection_params()

View File

@ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [
"MIDDLEWARE", "MIDDLEWARE",
"AUTHENTICATION_BACKENDS", "AUTHENTICATION_BACKENDS",
"CELERY", "CELERY",
"SPECTACULAR_SETTINGS",
"REST_FRAMEWORK",
] ]
SILENCED_SYSTEM_CHECKS = [ SILENCED_SYSTEM_CHECKS = [
@ -470,8 +468,6 @@ def _update_settings(app_path: str):
TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", [])) TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", [])) MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", [])) AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {})) CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module): for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS: if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:

View File

@ -1,26 +0,0 @@
import os
import warnings
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from authentik.lib.config import CONFIG
def setup():
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
"ignore",
"defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
"ignore",
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")

View File

@ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType
from django.test.runner import DiscoverRunner from django.test.runner import DiscoverRunner
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.config import CONFIG from authentik.lib.config import CONFIG
from authentik.lib.sentry import sentry_init from authentik.lib.sentry import sentry_init
from authentik.root.signals import post_startup, pre_startup, startup from authentik.root.signals import post_startup, pre_startup, startup
@ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
for key, value in test_config.items(): for key, value in test_config.items():
CONFIG.set(key, value) CONFIG.set(key, value)
ASN_CONTEXT_PROCESSOR.load()
GEOIP_CONTEXT_PROCESSOR.load()
sentry_init() sentry_init()
self.logger.debug("Test environment configured") self.logger.debug("Test environment configured")

View File

@ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str):
return return
# Delete all sync tasks from the cache # Delete all sync tasks from the cache
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete() DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
task = chain(
# The order of these operations needs to be preserved as each depends on the previous one(s) # User and group sync can happen at once, they have no dependencies on each other
# 1. User and group sync can happen simultaneously group(
# 2. Membership sync needs to run afterwards ldap_sync_paginator(source, UserLDAPSynchronizer)
# 3. Finally, user and group deletions can happen simultaneously + ldap_sync_paginator(source, GroupLDAPSynchronizer),
user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator( ),
source, GroupLDAPSynchronizer # Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
# Finally, deletions. What we'd really like to do here is something like
# ```
# user_identifiers = <ldap query>
# User.objects.exclude(
# usersourceconnection__identifier__in=user_uniqueness_identifiers,
# ).delete()
# ```
# This runs into performance issues in large installations. So instead we spread the
# work out into three steps:
# 1. Get every object from the LDAP source.
# 2. Mark every object as "safe" in the database. This is quick, but any error could
# mean deleting users which should not be deleted, so we do it immediately, in
# large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks.
group(
ldap_sync_paginator(source, UserLDAPForwardDeletion)
+ ldap_sync_paginator(source, GroupLDAPForwardDeletion),
),
) )
membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer) task()
user_group_deletion = ldap_sync_paginator(
source, UserLDAPForwardDeletion
) + ldap_sync_paginator(source, GroupLDAPForwardDeletion)
# Celery is buggy with empty groups, so we are careful only to add non-empty groups.
# See https://github.com/celery/celery/issues/9772
task_groups = []
if user_group_sync:
task_groups.append(group(user_group_sync))
if membership_sync:
task_groups.append(group(membership_sync))
if user_group_deletion:
task_groups.append(group(user_group_deletion))
all_tasks = chain(task_groups)
all_tasks()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list: def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:

View File

@ -1,277 +0,0 @@
"""Test SCIM Group"""
from json import dumps
from uuid import uuid4
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.sources.scim.models import (
SCIMSource,
SCIMSourceGroup,
)
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
class TestSCIMGroups(APITestCase):
"""Test SCIM Group view"""
def setUp(self) -> None:
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
def test_group_list(self):
"""Test full group list"""
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_group_list_single(self):
"""Test full group list (single group)"""
group = Group.objects.create(name=generate_id())
user = create_test_user()
group.users.add(user)
SCIMSourceGroup.objects.create(
source=self.source,
group=group,
id=str(uuid4()),
)
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(group.pk),
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
SCIMGroupSchema.model_validate_json(response.content, strict=True)
def test_group_create(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members(self):
"""Test group create"""
user = create_test_user()
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{
"displayName": generate_id(),
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members_empty(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_duplicate(self):
"""Test group create (duplicate)"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 409)
self.assertJSONEqual(
response.content,
{
"detail": "Group with ID exists already.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"scimType": "uniqueness",
"status": 409,
},
)
def test_group_update(self):
"""Test group update"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
def test_group_update_non_existent(self):
"""Test group update"""
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(uuid4()),
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=404)
self.assertJSONEqual(
response.content,
{
"detail": "Group not found.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"status": 404,
},
)
def test_group_patch_add(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "Add",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertTrue(group.users.filter(pk=user.pk).exists())
def test_group_patch_remove(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(user)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "remove",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertFalse(group.users.filter(pk=user.pk).exists())
def test_group_delete(self):
"""Test group delete"""
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=204)

View File

@ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase):
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"], SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
"0123456789", "0123456789",
) )
def test_user_update(self):
"""Test user update"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
data=dumps(
{
"id": str(existing.pk),
"userName": generate_id(),
"externalId": ext_id,
"emails": [
{
"primary": True,
"value": user.email,
}
],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_user_delete(self):
"""Test user delete"""
user = create_test_user()
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 204)

View File

@ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.views import APIView from rest_framework.views import APIView
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import Token, TokenIntents, User from authentik.core.models import Token, TokenIntents, User
from authentik.sources.scim.models import SCIMSource from authentik.sources.scim.models import SCIMSource
@ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication):
_username, _, password = b64decode(key.encode()).decode().partition(":") _username, _, password = b64decode(key.encode()).decode().partition(":")
token = self.check_token(password, source_slug) token = self.check_token(password, source_slug)
if token: if token:
CTX_AUTH_VIA.set("scim_basic")
return (token.user, token) return (token.user, token)
return None return None
@ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication):
token = self.check_token(key, source_slug) token = self.check_token(key, source_slug)
if not token: if not token:
return None return None
CTX_AUTH_VIA.set("scim_token")
return (token.user, token) return (token.user, token)

View File

@ -1,11 +1,13 @@
"""SCIM Utils""" """SCIM Utils"""
from typing import Any from typing import Any
from urllib.parse import urlparse
from django.conf import settings from django.conf import settings
from django.core.paginator import Page, Paginator from django.core.paginator import Page, Paginator
from django.db.models import Q, QuerySet from django.db.models import Q, QuerySet
from django.http import HttpRequest from django.http import HttpRequest
from django.urls import resolve
from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
@ -44,7 +46,7 @@ class SCIMView(APIView):
logger: BoundLogger logger: BoundLogger
permission_classes = [IsAuthenticated] permission_classes = [IsAuthenticated]
parser_classes = [SCIMParser, JSONParser] parser_classes = [SCIMParser]
renderer_classes = [SCIMRenderer] renderer_classes = [SCIMRenderer]
def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None: def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
@ -54,6 +56,28 @@ class SCIMView(APIView):
def get_authenticators(self): def get_authenticators(self):
return [SCIMTokenAuth(self)] return [SCIMTokenAuth(self)]
def patch_resolve_value(self, raw_value: dict) -> User | Group | None:
"""Attempt to resolve a raw `value` attribute of a patch operation into
a database model"""
model = User
query = {}
if "$ref" in raw_value:
url = urlparse(raw_value["$ref"])
if match := resolve(url.path):
if match.url_name == "v2-users":
model = User
query = {"pk": int(match.kwargs["user_id"])}
elif "type" in raw_value:
match raw_value["type"]:
case "User":
model = User
query = {"pk": int(raw_value["value"])}
case "Group":
model = Group
else:
return None
return model.objects.filter(**query).first()
def filter_parse(self, request: Request): def filter_parse(self, request: Request):
"""Parse the path of a Patch Operation""" """Parse the path of a Patch Operation"""
path = request.query_params.get("filter") path = request.query_params.get("filter")

View File

@ -1,58 +0,0 @@
from enum import Enum
from pydanticscim.responses import SCIMError as BaseSCIMError
from rest_framework.exceptions import ValidationError
class SCIMErrorTypes(Enum):
invalid_filter = "invalidFilter"
too_many = "tooMany"
uniqueness = "uniqueness"
mutability = "mutability"
invalid_syntax = "invalidSyntax"
invalid_path = "invalidPath"
no_target = "noTarget"
invalid_value = "invalidValue"
invalid_vers = "invalidVers"
sensitive = "sensitive"
class SCIMError(BaseSCIMError):
scimType: SCIMErrorTypes | None = None
detail: str | None = None
class SCIMValidationError(ValidationError):
status_code = 400
default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400)
def __init__(self, detail: SCIMError | None):
if detail is None:
detail = self.default_detail
detail.status = self.status_code
self.detail = detail.model_dump(mode="json", exclude_none=True)
class SCIMConflictError(SCIMValidationError):
status_code = 409
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
scimType=SCIMErrorTypes.uniqueness,
status=self.status_code,
)
)
class SCIMNotFoundError(SCIMValidationError):
status_code = 404
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
status=self.status_code,
)
)

View File

@ -4,25 +4,19 @@ from uuid import uuid4
from django.db.models import Q from django.db.models import Q
from django.db.transaction import atomic from django.db.transaction import atomic
from django.http import QueryDict from django.http import Http404, QueryDict
from django.urls import reverse from django.urls import reverse
from pydantic import ValidationError as PydanticValidationError from pydantic import ValidationError as PydanticValidationError
from pydanticscim.group import GroupMember from pydanticscim.group import GroupMember
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from scim2_filter_parser.attr_paths import AttrPath
from authentik.core.models import Group, User from authentik.core.models import Group, User
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
from authentik.sources.scim.models import SCIMSourceGroup from authentik.sources.scim.models import SCIMSourceGroup
from authentik.sources.scim.views.v2.base import SCIMObjectView from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import (
SCIMConflictError,
SCIMNotFoundError,
SCIMValidationError,
)
class GroupsView(SCIMObjectView): class GroupsView(SCIMObjectView):
@ -33,7 +27,7 @@ class GroupsView(SCIMObjectView):
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
"""Convert Group to SCIM data""" """Convert Group to SCIM data"""
payload = SCIMGroupModel( payload = SCIMGroupModel(
schemas=[SCIM_GROUP_SCHEMA], schemas=[SCIM_USER_SCHEMA],
id=str(scim_group.group.pk), id=str(scim_group.group.pk),
externalId=scim_group.id, externalId=scim_group.id,
displayName=scim_group.group.name, displayName=scim_group.group.name,
@ -64,7 +58,7 @@ class GroupsView(SCIMObjectView):
if group_id: if group_id:
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first() connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
if not connection: if not connection:
raise SCIMNotFoundError("Group not found.") raise Http404
return Response(self.group_to_scim(connection)) return Response(self.group_to_scim(connection))
connections = ( connections = (
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request)) base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
@ -125,7 +119,7 @@ class GroupsView(SCIMObjectView):
).first() ).first()
if connection: if connection:
self.logger.debug("Found existing group") self.logger.debug("Found existing group")
raise SCIMConflictError("Group with ID exists already.") return Response(status=409)
connection = self.update_group(None, request.data) connection = self.update_group(None, request.data)
return Response(self.group_to_scim(connection), status=201) return Response(self.group_to_scim(connection), status=201)
@ -135,44 +129,10 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id source=self.source, group__group_uuid=group_id
).first() ).first()
if not connection: if not connection:
raise SCIMNotFoundError("Group not found.") raise Http404
connection = self.update_group(connection, request.data) connection = self.update_group(connection, request.data)
return Response(self.group_to_scim(connection), status=200) return Response(self.group_to_scim(connection), status=200)
@atomic
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
"""Patch group handler"""
connection = SCIMSourceGroup.objects.filter(
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
for _op in request.data.get("Operations", []):
operation = PatchOperation.model_validate(_op)
if operation.op.lower() not in ["add", "remove", "replace"]:
raise SCIMValidationError()
attr_path = AttrPath(f'{operation.path} eq ""', {})
if attr_path.first_path == ("members", None, None):
# FIXME: this can probably be de-duplicated
if operation.op == PatchOp.add:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.add(*User.objects.filter(query))
elif operation.op == PatchOp.remove:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.remove(*User.objects.filter(query))
return Response(self.group_to_scim(connection), status=200)
@atomic @atomic
def delete(self, request: Request, group_id: str, **kwargs) -> Response: def delete(self, request: Request, group_id: str, **kwargs) -> Response:
"""Delete group handler""" """Delete group handler"""
@ -180,7 +140,7 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id source=self.source, group__group_uuid=group_id
).first() ).first()
if not connection: if not connection:
raise SCIMNotFoundError("Group not found.") raise Http404
connection.group.delete() connection.group.delete()
connection.delete() connection.delete()
return Response(status=204) return Response(status=204)

View File

@ -1,11 +1,11 @@
"""SCIM Meta views""" """SCIM Meta views"""
from django.http import Http404
from django.urls import reverse from django.urls import reverse
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
class ResourceTypesView(SCIMView): class ResourceTypesView(SCIMView):
@ -138,7 +138,7 @@ class ResourceTypesView(SCIMView):
resource = [x for x in resource_types if x.get("id") == resource_type] resource = [x for x in resource_types if x.get("id") == resource_type]
if resource: if resource:
return Response(resource[0]) return Response(resource[0])
raise SCIMNotFoundError("Resource not found.") raise Http404
return Response( return Response(
{ {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -3,12 +3,12 @@
from json import loads from json import loads
from django.conf import settings from django.conf import settings
from django.http import Http404
from django.urls import reverse from django.urls import reverse
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
with open( with open(
settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json", settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
@ -44,7 +44,7 @@ class SchemaView(SCIMView):
schema = [x for x in schemas if x.get("id") == schema_uri] schema = [x for x in schemas if x.get("id") == schema_uri]
if schema: if schema:
return Response(schema[0]) return Response(schema[0])
raise SCIMNotFoundError("Schema not found.") raise Http404
return Response( return Response(
{ {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"], "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView):
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
"authenticationSchemes": auth_schemas, "authenticationSchemes": auth_schemas,
# We only support patch for groups currently, so don't broadly advertise it.
# Implementations that require Group patch will use it regardless of this flag.
"patch": {"supported": False}, "patch": {"supported": False},
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0}, "bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
"filter": { "filter": {

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from django.db.models import Q from django.db.models import Q
from django.db.transaction import atomic from django.db.transaction import atomic
from django.http import QueryDict from django.http import Http404, QueryDict
from django.urls import reverse from django.urls import reverse
from pydanticscim.user import Email, EmailKind, Name from pydanticscim.user import Email, EmailKind, Name
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
@ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import User as SCIMUserModel from authentik.providers.scim.clients.schema import User as SCIMUserModel
from authentik.sources.scim.models import SCIMSourceUser from authentik.sources.scim.models import SCIMSourceUser
from authentik.sources.scim.views.v2.base import SCIMObjectView from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
class UsersView(SCIMObjectView): class UsersView(SCIMObjectView):
@ -70,7 +69,7 @@ class UsersView(SCIMObjectView):
.first() .first()
) )
if not connection: if not connection:
raise SCIMNotFoundError("User not found.") raise Http404
return Response(self.user_to_scim(connection)) return Response(self.user_to_scim(connection))
connections = ( connections = (
SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk") SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
@ -123,7 +122,7 @@ class UsersView(SCIMObjectView):
).first() ).first()
if connection: if connection:
self.logger.debug("Found existing user") self.logger.debug("Found existing user")
raise SCIMConflictError("Group with ID exists already.") return Response(status=409)
connection = self.update_user(None, request.data) connection = self.update_user(None, request.data)
return Response(self.user_to_scim(connection), status=201) return Response(self.user_to_scim(connection), status=201)
@ -131,7 +130,7 @@ class UsersView(SCIMObjectView):
"""Update user handler""" """Update user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection: if not connection:
raise SCIMNotFoundError("User not found.") raise Http404
self.update_user(connection, request.data) self.update_user(connection, request.data)
return Response(self.user_to_scim(connection), status=200) return Response(self.user_to_scim(connection), status=200)
@ -140,7 +139,7 @@ class UsersView(SCIMObjectView):
"""Delete user handler""" """Delete user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first() connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection: if not connection:
raise SCIMNotFoundError("User not found.") raise Http404
connection.user.delete() connection.user.delete()
connection.delete() connection.delete()
return Response(status=204) return Response(status=204)

View File

@ -1,7 +1,6 @@
"""Validation stage challenge checking""" """Validation stage challenge checking"""
from json import loads from json import loads
from typing import TYPE_CHECKING
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http import HttpRequest from django.http import HttpRequest
@ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice
from authentik.stages.authenticator_sms.models import SMSDevice from authentik.stages.authenticator_sms.models import SMSDevice
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView
class DeviceChallenge(PassiveSerializer): class DeviceChallenge(PassiveSerializer):
@ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer):
def get_challenge_for_device( def get_challenge_for_device(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device request: HttpRequest, stage: AuthenticatorValidateStage, device: Device
) -> dict: ) -> dict:
"""Generate challenge for a single device""" """Generate challenge for a single device"""
if isinstance(device, WebAuthnDevice): if isinstance(device, WebAuthnDevice):
return get_webauthn_challenge(stage_view, stage, device) return get_webauthn_challenge(request, stage, device)
if isinstance(device, EmailDevice): if isinstance(device, EmailDevice):
return {"email": mask_email(device.email)} return {"email": mask_email(device.email)}
# Code-based challenges have no hints # Code-based challenges have no hints
@ -67,30 +64,26 @@ def get_challenge_for_device(
def get_webauthn_challenge_without_user( def get_webauthn_challenge_without_user(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage request: HttpRequest, stage: AuthenticatorValidateStage
) -> dict: ) -> dict:
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check """Same as `get_webauthn_challenge`, but allows any client device. We can then later check
who the device belongs to.""" who the device belongs to."""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
authentication_options = generate_authentication_options( authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request), rp_id=get_rp_id(request),
allow_credentials=[], allow_credentials=[],
user_verification=UserVerificationRequirement(stage.webauthn_user_verification), user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
) )
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
authentication_options.challenge
)
return loads(options_to_json(authentication_options)) return loads(options_to_json(authentication_options))
def get_webauthn_challenge( def get_webauthn_challenge(
stage_view: "AuthenticatorValidateStageView", request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None
stage: AuthenticatorValidateStage,
device: WebAuthnDevice | None = None,
) -> dict: ) -> dict:
"""Send the client a challenge that we'll check later""" """Send the client a challenge that we'll check later"""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None) request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
allowed_credentials = [] allowed_credentials = []
@ -101,14 +94,12 @@ def get_webauthn_challenge(
allowed_credentials.append(user_device.descriptor) allowed_credentials.append(user_device.descriptor)
authentication_options = generate_authentication_options( authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request), rp_id=get_rp_id(request),
allow_credentials=allowed_credentials, allow_credentials=allowed_credentials,
user_verification=UserVerificationRequirement(stage.webauthn_user_verification), user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
) )
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = ( request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
authentication_options.challenge
)
return loads(options_to_json(authentication_options)) return loads(options_to_json(authentication_options))
@ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device: def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
"""Validate WebAuthn Challenge""" """Validate WebAuthn Challenge"""
request = stage_view.request request = stage_view.request
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE) challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
stage: AuthenticatorValidateStage = stage_view.executor.current_stage stage: AuthenticatorValidateStage = stage_view.executor.current_stage
try: try:
credential = parse_authentication_credential_json(data) credential = parse_authentication_credential_json(data)

View File

@ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={ data={
"device_class": device_class, "device_class": device_class,
"device_uid": device.pk, "device_uid": device.pk,
"challenge": get_challenge_for_device(self, stage, device), "challenge": get_challenge_for_device(self.request, stage, device),
"last_used": device.last_used, "last_used": device.last_used,
} }
) )
@ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_class": DeviceClasses.WEBAUTHN, "device_class": DeviceClasses.WEBAUTHN,
"device_uid": -1, "device_uid": -1,
"challenge": get_webauthn_challenge_without_user( "challenge": get_webauthn_challenge_without_user(
self, self.request,
self.executor.current_stage, self.executor.current_stage,
), ),
"last_used": None, "last_used": None,

View File

@ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice, WebAuthnDevice,
WebAuthnDeviceType, WebAuthnDeviceType,
) )
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
from authentik.stages.identification.models import IdentificationStage, UserFields from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.user_login.models import UserLoginStage from authentik.stages.user_login.models import UserLoginStage
@ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
device_classes=[DeviceClasses.WEBAUTHN], device_classes=[DeviceClasses.WEBAUTHN],
webauthn_user_verification=UserVerification.PREFERRED, webauthn_user_verification=UserVerification.PREFERRED,
) )
plan = FlowPlan("") challenge = get_challenge_for_device(request, stage, webauthn_device)
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
del challenge["challenge"] del challenge["challenge"]
self.assertEqual( self.assertEqual(
challenge, challenge,
@ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
validate_challenge_webauthn( validate_challenge_webauthn(
{}, {}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user
StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request),
self.user,
) )
def test_device_challenge_webauthn_restricted(self): def test_device_challenge_webauthn_restricted(self):
@ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0, sign_count=0,
rp_id=generate_id(), rp_id=generate_id(),
) )
plan = FlowPlan("") challenge = get_challenge_for_device(request, stage, webauthn_device)
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
self.assertEqual( self.assertEqual(
challenge["allowCredentials"], challenge,
[ {
{ "allowCredentials": [
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU", {
"type": "public-key", "id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
} "type": "public-key",
], }
) ],
self.assertIsNotNone(challenge["challenge"]) "challenge": bytes_to_base64url(webauthn_challenge),
self.assertEqual( "rpId": "testserver",
challenge["rpId"], "timeout": 60000,
"testserver", "userVerification": "preferred",
) },
self.assertEqual(
challenge["timeout"],
60000,
)
self.assertEqual(
challenge["userVerification"],
"preferred",
) )
def test_get_challenge_userless(self): def test_get_challenge_userless(self):
@ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0, sign_count=0,
rp_id=generate_id(), rp_id=generate_id(),
) )
plan = FlowPlan("") challenge = get_webauthn_challenge_without_user(request, stage)
stage_view = AuthenticatorValidateStageView( webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request self.assertEqual(
challenge,
{
"allowCredentials": [],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
) )
challenge = get_webauthn_challenge_without_user(stage_view, stage)
self.assertEqual(challenge["allowCredentials"], [])
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(challenge["rpId"], "testserver")
self.assertEqual(challenge["timeout"], 60000)
self.assertEqual(challenge["userVerification"], "preferred")
def test_validate_challenge_unrestricted(self): def test_validate_challenge_unrestricted(self):
"""Test webauthn authentication (unrestricted webauthn device)""" """Test webauthn authentication (unrestricted webauthn device)"""
@ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None, "last_used": None,
} }
] ]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
) )
session[SESSION_KEY_PLAN] = plan
session.save() session.save()
response = self.client.post( response = self.client.post(
@ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None, "last_used": None,
} }
] ]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ" "aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
) )
session[SESSION_KEY_PLAN] = plan
session.save() session.save()
response = self.client.post( response = self.client.post(
@ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None, "last_used": None,
} }
] ]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA" "g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )
session[SESSION_KEY_PLAN] = plan
session.save() session.save()
response = self.client.post( response = self.client.post(
@ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
not_configured_action=NotConfiguredAction.CONFIGURE, not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.WEBAUTHN], device_classes=[DeviceClasses.WEBAUTHN],
) )
plan = FlowPlan(flow.pk.hex) stage_view = AuthenticatorValidateStageView(
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes( FlowExecutorView(flow=flow, current_stage=stage), request=request
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
) )
request = get_request("/") request = get_request("/")
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
request.session.save()
stage_view = AuthenticatorValidateStageView( stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request FlowExecutorView(flow=flow, current_stage=stage), request=request
) )
request.META["SERVER_NAME"] = "localhost" request.META["SERVER_NAME"] = "localhost"
request.META["SERVER_PORT"] = "9000" request.META["SERVER_PORT"] = "9000"

View File

@ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer):
"resident_key_requirement", "resident_key_requirement",
"device_type_restrictions", "device_type_restrictions",
"device_type_restrictions_obj", "device_type_restrictions_obj",
"max_attempts",
] ]

File diff suppressed because one or more lines are too long

View File

@ -1,21 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-13 22:41
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"authentik_stages_authenticator_webauthn",
"0012_webauthndevice_created_webauthndevice_last_updated_and_more",
),
]
operations = [
migrations.AddField(
model_name="authenticatorwebauthnstage",
name="max_attempts",
field=models.PositiveIntegerField(default=0),
),
]

View File

@ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage):
device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True) device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True)
max_attempts = models.PositiveIntegerField(default=0)
@property @property
def serializer(self) -> type[BaseSerializer]: def serializer(self) -> type[BaseSerializer]:
from authentik.stages.authenticator_webauthn.api.stages import ( from authentik.stages.authenticator_webauthn.api.stages import (

View File

@ -5,13 +5,12 @@ from uuid import UUID
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict from django.http.request import QueryDict
from django.utils.translation import gettext as __
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from webauthn import options_to_json from webauthn import options_to_json
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from webauthn.helpers.exceptions import WebAuthnException from webauthn.helpers.exceptions import InvalidRegistrationResponse
from webauthn.helpers.structs import ( from webauthn.helpers.structs import (
AttestationConveyancePreference, AttestationConveyancePreference,
AuthenticatorAttachment, AuthenticatorAttachment,
@ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import (
) )
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge" SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge"
PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt"
class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge): class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge):
@ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
def validate_response(self, response: dict) -> dict: def validate_response(self, response: dict) -> dict:
"""Validate webauthn challenge response""" """Validate webauthn challenge response"""
challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
try: try:
registration: VerifiedRegistration = verify_registration_response( registration: VerifiedRegistration = verify_registration_response(
@ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
expected_rp_id=get_rp_id(self.request), expected_rp_id=get_rp_id(self.request),
expected_origin=get_origin(self.request), expected_origin=get_origin(self.request),
) )
except WebAuthnException as exc: except InvalidRegistrationResponse as exc:
self.stage.logger.warning("registration failed", exc=exc) self.stage.logger.warning("registration failed", exc=exc)
raise ValidationError(f"Registration failed. Error: {exc}") from None raise ValidationError(f"Registration failed. Error: {exc}") from None
@ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response_class = AuthenticatorWebAuthnChallengeResponse response_class = AuthenticatorWebAuthnChallengeResponse
def get_challenge(self, *args, **kwargs) -> Challenge: def get_challenge(self, *args, **kwargs) -> Challenge:
# clear session variables prior to starting a new registration
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
stage: AuthenticatorWebAuthnStage = self.executor.current_stage stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
# clear flow variables prior to starting a new registration
self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
user = self.get_pending_user() user = self.get_pending_user()
# library accepts none so we store null in the database, but if there is a value # library accepts none so we store null in the database, but if there is a value
@ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
attestation=AttestationConveyancePreference.DIRECT, attestation=AttestationConveyancePreference.DIRECT,
) )
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge
self.request.session.save()
return AuthenticatorWebAuthnChallenge( return AuthenticatorWebAuthnChallenge(
data={ data={
"registration": loads(options_to_json(registration_options)), "registration": loads(options_to_json(registration_options)),
@ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response.user = self.get_pending_user() response.user = self.get_pending_user()
return response return response
def challenge_invalid(self, response):
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1
if (
stage.max_attempts > 0
and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts
):
return self.executor.stage_invalid(
__(
"Exceeded maximum attempts. "
"Contact your {brand} administrator for help.".format(
brand=self.request.brand.branding_title
)
)
)
return super().challenge_invalid(response)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# Webauthn Challenge has already been validated # Webauthn Challenge has already been validated
webauthn_credential: VerifiedRegistration = response.validated_data["response"] webauthn_credential: VerifiedRegistration = response.validated_data["response"]
@ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
else: else:
return self.executor.stage_invalid("Device with Credential ID already exists.") return self.executor.stage_invalid("Device with Credential ID already exists.")
return self.executor.stage_ok() return self.executor.stage_ok()
def cleanup(self):
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)

View File

@ -18,7 +18,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice, WebAuthnDevice,
WebAuthnDeviceType, WebAuthnDeviceType,
) )
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
@ -57,9 +57,6 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
response = self.client.get( response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
) )
plan: FlowPlan = self.client.session[SESSION_KEY_PLAN]
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
session = self.client.session session = self.client.session
self.assertStageResponse( self.assertStageResponse(
@ -73,7 +70,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"name": self.user.username, "name": self.user.username,
"displayName": self.user.name, "displayName": self.user.name,
}, },
"challenge": bytes_to_base64url(plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]), "challenge": bytes_to_base64url(session[SESSION_KEY_WEBAUTHN_CHALLENGE]),
"pubKeyCredParams": [ "pubKeyCredParams": [
{"type": "public-key", "alg": -7}, {"type": "public-key", "alg": -7},
{"type": "public-key", "alg": -8}, {"type": "public-key", "alg": -8},
@ -100,11 +97,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"""Test registration""" """Test registration"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session session = self.client.session
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save() session.save()
response = self.client.post( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -149,11 +146,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session session = self.client.session
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save() session.save()
response = self.client.post( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -212,11 +209,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session session = self.client.session
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save() session.save()
response = self.client.post( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -262,11 +259,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()]) plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session session = self.client.session
session[SESSION_KEY_PLAN] = plan session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save() session.save()
response = self.client.post( response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -301,109 +298,3 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists()) self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists())
def test_register_max_retries(self):
"""Test registration (exceeding max retries)"""
self.stage.max_attempts = 2
self.stage.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
# first failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-authenticator-webauthn",
response_errors={
"response": [
{
"string": (
"Registration failed. Error: Unable to decode "
"client_data_json bytes as JSON"
),
"code": "invalid",
}
]
},
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())
# Second failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-access-denied",
error_message=(
"Exceeded maximum attempts. Contact your authentik administrator for help."
),
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())

Some files were not shown because too many files have changed in this diff Show More