Compare commits

..

4 Commits

918 changed files with 20979 additions and 22725 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2025.6.2
current_version = 2025.6.1
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
@ -21,8 +21,6 @@ optional_value = final
[bumpversion:file:package.json]
[bumpversion:file:package-lock.json]
[bumpversion:file:docker-compose.yml]
[bumpversion:file:schema.yml]
@ -33,4 +31,6 @@ optional_value = final
[bumpversion:file:internal/constants/constants.go]
[bumpversion:file:web/src/common/constants.ts]
[bumpversion:file:lifecycle/aws/template.yaml]

View File

@ -7,9 +7,6 @@ charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.toml]
indent_size = 2
[*.html]
indent_size = 2

View File

@ -15,8 +15,8 @@ jobs:
matrix:
version:
- docs
- version-2025-4
- version-2025-2
- version-2024-12
steps:
- uses: actions/checkout@v4
- run: |

View File

@ -202,7 +202,7 @@ jobs:
uses: actions/cache@v4
with:
path: web/dist
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
- name: prepare web ui
if: steps.cache-web.outputs.cache-hit != 'true'
working-directory: web

View File

@ -41,27 +41,6 @@ jobs:
- name: test
working-directory: website/
run: npm test
build:
runs-on: ubuntu-latest
name: ${{ matrix.job }}
strategy:
fail-fast: false
matrix:
job:
- build
- build:integrations
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: website/package.json
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
run: npm ci
- name: build
working-directory: website/
run: npm run ${{ matrix.job }}
build-container:
runs-on: ubuntu-latest
permissions:
@ -115,7 +94,6 @@ jobs:
needs:
- lint
- test
- build
- build-container
runs-on: ubuntu-latest
steps:

View File

@ -2,7 +2,7 @@ name: "CodeQL"
on:
push:
branches: [main, next, version*]
branches: [main, "*", next, version*]
pull_request:
branches: [main]
schedule:

View File

@ -6,15 +6,13 @@
"!Context scalar",
"!Enumerate sequence",
"!Env scalar",
"!Env sequence",
"!Find sequence",
"!Format sequence",
"!If sequence",
"!Index scalar",
"!KeyOf scalar",
"!Value scalar",
"!AtIndex scalar",
"!ParseJSON scalar"
"!AtIndex scalar"
],
"typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index",

View File

@ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 4: Download uv
FROM ghcr.io/astral-sh/uv:0.7.15 AS uv
FROM ghcr.io/astral-sh/uv:0.7.12 AS uv
# Stage 5: Base python image
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base
ENV VENV_PATH="/ak-root/.venv" \
PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \

View File

@ -86,10 +86,6 @@ dev-create-db:
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
update-test-mmdb: ## Update test GeoIP and ASN Databases
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb
#########################
## API Schema
#########################
@ -98,7 +94,7 @@ gen-build: ## Extract the schema from the database
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
uv run ak make_blueprint_schema --file blueprints/schema.json
uv run ak make_blueprint_schema > blueprints/schema.json
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
@ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescri
--additional-properties=npmVersion=${NPM_VERSION} \
--git-repo-id authentik \
--git-user-id goauthentik
cd ${PWD}/${GEN_API_TS} && npm link
cd ${PWD}/web && npm link @goauthentik/api
mkdir -p web/node_modules/@goauthentik/api
cd ${PWD}/${GEN_API_TS} && npm i
\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api
gen-client-py: gen-clean-py ## Build and install the authentik API for Python
docker run \

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2025.6.2"
__version__ = "2025.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -72,33 +72,20 @@ class Command(BaseCommand):
"additionalProperties": True,
},
"entries": {
"anyOf": [
{
"type": "array",
"items": {"$ref": "#/$defs/blueprint_entry"},
},
{
"type": "object",
"additionalProperties": {
"type": "array",
"items": {"$ref": "#/$defs/blueprint_entry"},
},
},
],
"type": "array",
"items": {
"oneOf": [],
},
},
},
"$defs": {"blueprint_entry": {"oneOf": []}},
"$defs": {},
}
def add_arguments(self, parser):
parser.add_argument("--file", type=str)
@no_translations
def handle(self, *args, file: str, **options):
def handle(self, *args, **options):
"""Generate JSON Schema for blueprints"""
self.build()
with open(file, "w") as _schema:
_schema.write(dumps(self.schema, indent=4, default=Command.json_default))
self.stdout.write(dumps(self.schema, indent=4, default=Command.json_default))
@staticmethod
def json_default(value: Any) -> Any:
@ -125,7 +112,7 @@ class Command(BaseCommand):
}
)
model_path = f"{model._meta.app_label}.{model._meta.model_name}"
self.schema["$defs"]["blueprint_entry"]["oneOf"].append(
self.schema["properties"]["entries"]["items"]["oneOf"].append(
self.template_entry(model_path, model, serializer)
)
@ -147,7 +134,7 @@ class Command(BaseCommand):
"id": {"type": "string"},
"state": {
"type": "string",
"enum": sorted([s.value for s in BlueprintEntryDesiredState]),
"enum": [s.value for s in BlueprintEntryDesiredState],
"default": "present",
},
"conditions": {"type": "array", "items": {"type": "boolean"}},
@ -218,7 +205,7 @@ class Command(BaseCommand):
"type": "object",
"required": ["permission"],
"properties": {
"permission": {"type": "string", "enum": sorted(perms)},
"permission": {"type": "string", "enum": perms},
"user": {"type": "integer"},
"role": {"type": "string"},
},

View File

@ -1,11 +1,10 @@
version: 1
entries:
foo:
- identifiers:
name: "%(id)s"
slug: "%(id)s"
model: authentik_flows.flow
state: present
attrs:
designation: stage_configuration
title: foo
- identifiers:
name: "%(id)s"
slug: "%(id)s"
model: authentik_flows.flow
state: present
attrs:
designation: stage_configuration
title: foo

View File

@ -37,7 +37,6 @@ entries:
- attrs:
attributes:
env_null: !Env [bar-baz, null]
json_parse: !ParseJSON '{"foo": "bar"}'
policy_pk1:
!Format [
"%s-%s",

View File

@ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable:
for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
if "local" in str(blueprint_file) or "testing" in str(blueprint_file):
if "local" in str(blueprint_file):
continue
setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))

View File

@ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase):
},
"nested_context": "context-nested-value",
"env_null": None,
"json_parse": {"foo": "bar"},
"at_index_sequence": "foo",
"at_index_sequence_default": "non existent",
"at_index_mapping": 2,

View File

@ -6,7 +6,6 @@ from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from functools import reduce
from json import JSONDecodeError, loads
from operator import ixor
from os import getenv
from typing import Any, Literal, Union
@ -192,18 +191,11 @@ class Blueprint:
"""Dataclass used for a full export"""
version: int = field(default=1)
entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = field(default_factory=list)
entries: list[BlueprintEntry] = field(default_factory=list)
context: dict = field(default_factory=dict)
metadata: BlueprintMetadata | None = field(default=None)
def iter_entries(self) -> Iterable[BlueprintEntry]:
if isinstance(self.entries, dict):
for _section, entries in self.entries.items():
yield from entries
else:
yield from self.entries
class YAMLTag:
"""Base class for all YAML Tags"""
@ -234,7 +226,7 @@ class KeyOf(YAMLTag):
self.id_from = node.value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
for _entry in blueprint.iter_entries():
for _entry in blueprint.entries:
if _entry.id == self.id_from and _entry._state.instance:
# Special handling for PolicyBindingModels, as they'll have a different PK
# which is used when creating policy bindings
@ -292,22 +284,6 @@ class Context(YAMLTag):
return value
class ParseJSON(YAMLTag):
"""Parse JSON from context/env/etc value"""
raw: str
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
super().__init__()
self.raw = node.value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
try:
return loads(self.raw)
except JSONDecodeError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc
class Format(YAMLTag):
"""Format a string"""
@ -683,7 +659,6 @@ class BlueprintLoader(SafeLoader):
self.add_constructor("!Value", Value)
self.add_constructor("!Index", Index)
self.add_constructor("!AtIndex", AtIndex)
self.add_constructor("!ParseJSON", ParseJSON)
class EntryInvalidError(SentryIgnoredException):

View File

@ -384,7 +384,7 @@ class Importer:
def _apply_models(self, raise_errors=False) -> bool:
"""Apply (create/update) models yaml"""
self.__pk_map = {}
for entry in self._import.iter_entries():
for entry in self._import.entries:
model_app_label, model_name = entry.get_model(self._import).split(".")
try:
model: type[SerializerModel] = registry.get_model(model_app_label, model_name)

View File

@ -47,7 +47,7 @@ class MetaModelRegistry:
models = apps.get_models()
for _, value in self.models.items():
models.append(value)
return sorted(models, key=str)
return models
def get_model(self, app_label: str, model_id: str) -> type[Model]:
"""Get model checks if any virtual models are registered, and falls back

View File

@ -1,6 +1,8 @@
"""Authenticator Devices API Views"""
from drf_spectacular.utils import extend_schema
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.fields import (
BooleanField,
@ -13,7 +15,6 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from authentik.core.api.users import ParamUserSerializer
from authentik.core.api.utils import MetaNameSerializer
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
from authentik.stages.authenticator import device_classes, devices_for_user
@ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
class DeviceSerializer(MetaNameSerializer):
"""Serializer for authenticator devices"""
"""Serializer for Duo authenticator devices"""
pk = CharField()
name = CharField()
@ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer):
last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True)
extra_description = SerializerMethodField()
external_id = SerializerMethodField()
def get_type(self, instance: Device) -> str:
"""Get type of device"""
return instance._meta.label
def get_extra_description(self, instance: Device) -> str | None:
def get_extra_description(self, instance: Device) -> str:
"""Get extra description"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.description if instance.device_type else None
return (
instance.device_type.description
if instance.device_type
else _("Extra description not available")
)
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
def get_external_id(self, instance: Device) -> str | None:
"""Get external Device ID"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.aaguid if instance.device_type else None
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
return ""
class DeviceViewSet(ViewSet):
@ -61,6 +57,7 @@ class DeviceViewSet(ViewSet):
serializer_class = DeviceSerializer
permission_classes = [IsAuthenticated]
@extend_schema(responses={200: DeviceSerializer(many=True)})
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
devices = devices_for_user(request.user)
@ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet):
yield from device_set
@extend_schema(
parameters=[ParamUserSerializer],
parameters=[
OpenApiParameter(
name="user",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
)
],
responses={200: DeviceSerializer(many=True)},
)
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
args = ParamUserSerializer(data=request.query_params)
args.is_valid(raise_exception=True)
return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data)
kwargs = {}
if "user" in request.query_params:
kwargs = {"user": request.query_params["user"]}
return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data)

View File

@ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger()
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False)
class UserGroupSerializer(ModelSerializer):
"""Simplified Group Serializer for user's groups"""
@ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
queryset = User.objects.none()
ordering = ["username"]
serializer_class = UserSerializer
filterset_class = UsersFilter
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
def get_ql_fields(self):
from djangoql.schema import BoolField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
StrField(User, "username"),
StrField(User, "name"),
StrField(User, "email"),
StrField(User, "path"),
BoolField(User, "is_active", nullable=True),
ChoiceSearchField(User, "type"),
JSONSearchField(User, "attributes", suggest_nested=False),
]
filterset_class = UsersFilter
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()

View File

@ -2,7 +2,6 @@
from typing import Any
from django.db import models
from django.db.models import Model
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_basic_type
@ -31,27 +30,7 @@ def is_dict(value: Any):
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class ModelSerializer(BaseModelSerializer):
# By default, JSON fields we have are used to store dictionaries
serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
serializer_field_mapping[models.JSONField] = JSONDictField
def create(self, validated_data):
instance = super().create(validated_data)
@ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer):
return instance
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class PassiveSerializer(Serializer):
"""Base serializer class which doesn't implement create/update methods"""

View File

@ -13,6 +13,7 @@ class Command(TenantCommand):
parser.add_argument("usernames", nargs="*", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()

View File

@ -18,7 +18,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_cte import CTE, with_cte
from django_cte import CTEQuerySet, With
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
@ -136,7 +136,7 @@ class AttributesMixin(models.Model):
return instance, False
class GroupQuerySet(QuerySet):
class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""
@ -165,9 +165,9 @@ class GroupQuerySet(QuerySet):
)
# Build the recursive query, see above
cte = CTE.recursive(make_cte)
cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group.
return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid))
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin):

View File

@ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase):
[
UserPasswordHistory(
user=self.user,
old_password="hunter1", # nosec
old_password="hunter1", # nosec B106
created_at=_now - timedelta(days=3),
),
UserPasswordHistory(
user=self.user,
old_password="hunter2", # nosec
old_password="hunter2", # nosec B106
created_at=_now - timedelta(days=2),
),
UserPasswordHistory(
user=self.user,
old_password="hunter3", # nosec
old_password="hunter3", # nosec B106
created_at=_now,
),
]

View File

@ -1,8 +1,10 @@
from hashlib import sha256
from django.contrib.auth.signals import user_logged_out
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http.request import HttpRequest
from guardian.shortcuts import assign_perm
from authentik.core.models import (
@ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
instance.save()
@receiver(user_logged_out)
def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_):
"""Session revoked trigger (user logged out)"""
if not request.session or not request.session.session_key or not user:
return
send_ssf_event(
EventTypes.CAEP_SESSION_REVOKED,
{
"initiating_entity": "user",
},
sub_id={
"format": "complex",
"session": {
"format": "opaque",
"id": sha256(request.session.session_key.encode("ascii")).hexdigest(),
},
"user": {
"format": "email",
"email": user.email,
},
},
request=request,
)
@receiver(pre_delete, sender=AuthenticatedSession)
def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
"""Session revoked trigger (users' session has been deleted)

View File

@ -1,12 +0,0 @@
"""Enterprise app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseSearchConfig(EnterpriseConfig):
"""Enterprise app config"""
name = "authentik.enterprise.search"
label = "authentik_search"
verbose_name = "authentik Enterprise.Search"
default = True

View File

@ -1,128 +0,0 @@
"""DjangoQL search"""
from collections import OrderedDict, defaultdict
from collections.abc import Generator
from django.db import connection
from django.db.models import Model, Q
from djangoql.compat import text_type
from djangoql.schema import StrField
class JSONSearchField(StrField):
"""JSON field for DjangoQL"""
model: Model
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
# Set this in the constructor to not clobber the type variable
self.type = "relation"
self.suggest_nested = suggest_nested
super().__init__(model, name, nullable)
def get_lookup(self, path, operator, value):
search = "__".join(path)
op, invert = self.get_operator(operator)
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
return ~q if invert else q
def json_field_keys(self) -> Generator[tuple[str]]:
with connection.cursor() as cursor:
cursor.execute(
f"""
WITH RECURSIVE "{self.name}_keys" AS (
SELECT
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
FROM {self.model._meta.db_table}
WHERE "{self.name}" IS NOT NULL
AND jsonb_typeof("{self.name}") = 'object'
UNION ALL
SELECT
ck.key_path_array || jsonb_object_keys(ck.value),
ck.value -> jsonb_object_keys(ck.value) AS value
FROM "{self.name}_keys" ck
WHERE jsonb_typeof(ck.value) = 'object'
),
unique_paths AS (
SELECT DISTINCT key_path_array
FROM "{self.name}_keys"
)
SELECT key_path_array FROM unique_paths;
""" # nosec
)
return (x[0] for x in cursor.fetchall())
def get_nested_options(self) -> OrderedDict:
"""Get keys of all nested objects to show autocomplete"""
if not self.suggest_nested:
return OrderedDict()
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
if not parent_parts:
parent_parts = []
path = parts.pop(0)
parent_parts.append(path)
relation_key = "_".join(parent_parts)
if len(parts) > 1:
out_dict = {
relation_key: {
parts[0]: {
"type": "relation",
"relation": f"{relation_key}_{parts[0]}",
}
}
}
child_paths = recursive_function(parts.copy(), parent_parts.copy())
child_paths.update(out_dict)
return child_paths
else:
return {relation_key: {parts[0]: {}}}
relation_structure = defaultdict(dict)
for relations in self.json_field_keys():
result = recursive_function([base_model_name] + relations)
for relation_key, value in result.items():
for sub_relation_key, sub_value in value.items():
if not relation_structure[relation_key].get(sub_relation_key, None):
relation_structure[relation_key][sub_relation_key] = sub_value
else:
relation_structure[relation_key][sub_relation_key].update(sub_value)
final_dict = defaultdict(dict)
for key, value in relation_structure.items():
for sub_key, sub_value in value.items():
if not sub_value:
final_dict[key][sub_key] = {
"type": "str",
"nullable": True,
}
else:
final_dict[key][sub_key] = sub_value
return OrderedDict(final_dict)
def relation(self) -> str:
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
class ChoiceSearchField(StrField):
def __init__(self, model=None, name=None, nullable=None):
super().__init__(model, name, nullable, suggest_options=True)
def get_options(self, search):
result = []
choices = self._field_choices()
if choices:
search = search.lower()
for c in choices:
choice = text_type(c[0])
if search in choice.lower():
result.append(choice)
return result

View File

@ -1,53 +0,0 @@
from rest_framework.response import Response
from authentik.api.pagination import Pagination
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch
class AutocompletePagination(Pagination):
def paginate_queryset(self, queryset, request, view=None):
self.view = view
return super().paginate_queryset(queryset, request, view)
def get_autocomplete(self):
schema = QLSearch().get_schema(self.request, self.view)
introspections = {}
if hasattr(self.view, "get_ql_fields"):
from authentik.enterprise.search.schema import AKQLSchemaSerializer
introspections = AKQLSchemaSerializer().serialize(
schema(self.page.paginator.object_list.model)
)
return introspections
def get_paginated_response(self, data):
previous_page_number = 0
if self.page.has_previous():
previous_page_number = self.page.previous_page_number()
next_page_number = 0
if self.page.has_next():
next_page_number = self.page.next_page_number()
return Response(
{
"pagination": {
"next": next_page_number,
"previous": previous_page_number,
"count": self.page.paginator.count,
"current": self.page.number,
"total_pages": self.page.paginator.num_pages,
"start_index": self.page.start_index(),
"end_index": self.page.end_index(),
},
"results": data,
"autocomplete": self.get_autocomplete(),
}
)
def get_paginated_response_schema(self, schema):
final_schema = super().get_paginated_response_schema(schema)
final_schema["properties"]["autocomplete"] = {
"$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}"
}
final_schema["required"].append("autocomplete")
return final_schema

View File

@ -1,78 +0,0 @@
"""DjangoQL search"""
from django.apps import apps
from django.db.models import QuerySet
from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
from authentik.enterprise.search.fields import JSONSearchField
LOGGER = get_logger()
AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete"
AUTOCOMPLETE_SCHEMA = {
"type": "object",
"additionalProperties": {},
}
class BaseSchema(DjangoQLSchema):
"""Base Schema which deals with JSON Fields"""
def resolve_name(self, name: Name):
model = self.model_label(self.current_model)
root_field = name.parts[0]
field = self.models[model].get(root_field)
# If the query goes into a JSON field, return the root
# field as the JSON field will do the rest
if isinstance(field, JSONSearchField):
# This is a workaround; build_filter will remove the right-most
# entry in the path as that is intended to be the same as the field
# however for JSON that is not the case
if name.parts[-1] != root_field:
name.parts.append(root_field)
return field
return super().resolve_name(name)
class QLSearch(SearchFilter):
"""rest_framework search filter which uses DjangoQL"""
@property
def enabled(self):
return apps.get_app_config("authentik_enterprise").enabled()
def get_search_terms(self, request) -> str:
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
params = request.query_params.get(self.search_param, "")
params = params.replace("\x00", "") # strip null characters
return params
def get_schema(self, request: Request, view) -> BaseSchema:
ql_fields = []
if hasattr(view, "get_ql_fields"):
ql_fields = view.get_ql_fields()
class InlineSchema(BaseSchema):
def get_fields(self, model):
return ql_fields or []
return InlineSchema
def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
search_query = self.get_search_terms(request)
schema = self.get_schema(request, view)
if len(search_query) == 0 or not self.enabled:
return super().filter_queryset(request, queryset, view)
try:
return apply_search(queryset, search_query, schema=schema)
except DjangoQLError as exc:
LOGGER.debug("Failed to parse search expression", exc=exc)
return super().filter_queryset(request, queryset, view)

View File

@ -1,29 +0,0 @@
from djangoql.serializers import DjangoQLSchemaSerializer
from drf_spectacular.generators import SchemaGenerator
from authentik.api.schema import create_component
from authentik.enterprise.search.fields import JSONSearchField
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
def serialize(self, schema):
serialization = super().serialize(schema)
for _, fields in schema.models.items():
for _, field in fields.items():
if not isinstance(field, JSONSearchField):
continue
serialization["models"].update(field.get_nested_options())
return serialization
def serialize_field(self, field):
result = super().serialize_field(field)
if isinstance(field, JSONSearchField):
result["relation"] = field.relation()
return result
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA)
return result

View File

@ -1,17 +0,0 @@
SPECTACULAR_SETTINGS = {
"POSTPROCESSING_HOOKS": [
"authentik.api.schema.postprocess_schema_responses",
"authentik.enterprise.search.schema.postprocess_schema_search_autocomplete",
"drf_spectacular.hooks.postprocess_schema_enums",
],
}
REST_FRAMEWORK = {
"DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination",
"DEFAULT_FILTER_BACKENDS": [
"authentik.enterprise.search.ql.QLSearch",
"authentik.rbac.filters.ObjectFilter",
"django_filters.rest_framework.DjangoFilterBackend",
"rest_framework.filters.OrderingFilter",
],
}

View File

@ -1,78 +0,0 @@
from json import loads
from unittest.mock import PropertyMock, patch
from urllib.parse import urlencode
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
class QLTest(APITestCase):
def setUp(self):
self.user = create_test_admin_user()
# ensure we have more than 1 user
create_test_admin_user()
def test_search(self):
"""Test simple search query"""
self.client.force_login(self.user)
query = f'username = "{self.user.username}"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_no_search(self):
"""Ensure works with no search query"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertNotEqual(content["pagination"]["count"], 1)
def test_search_no_ql(self):
"""Test simple search query (no QL)"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": self.user.username})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertGreaterEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_search_json(self):
"""Test search query with a JSON attribute"""
self.user.attributes = {"foo": {"bar": "baz"}}
self.user.save()
self.client.force_login(self.user)
query = 'attributes.foo.bar = "baz"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)

View File

@ -18,7 +18,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.ssf",
"authentik.enterprise.search",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source",

View File

@ -97,7 +97,6 @@ class SourceStageFinal(StageView):
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan
plan.context.update(self.executor.plan.context)
plan.context[PLAN_CONTEXT_IS_RESTORED] = token
response = plan.to_redirect(self.request, token.flow)
token.delete()

View File

@ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase):
plan: FlowPlan = session[SESSION_KEY_PLAN]
plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
plan.context["foo"] = "bar"
session[SESSION_KEY_PLAN] = plan
session.save()
# Pretend we've just returned from the source
with self.assertFlowFinishes() as ff:
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)
self.assertEqual(ff().context["foo"], "bar")
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)

View File

@ -132,22 +132,6 @@ class EventViewSet(ModelViewSet):
]
filterset_class = EventsFilter
def get_ql_fields(self):
from djangoql.schema import DateTimeField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
ChoiceSearchField(Event, "action"),
StrField(Event, "event_uuid"),
StrField(Event, "app", suggest_options=True),
StrField(Event, "client_ip"),
JSONSearchField(Event, "user", suggest_nested=False),
JSONSearchField(Event, "brand", suggest_nested=False),
JSONSearchField(Event, "context", suggest_nested=False),
DateTimeField(Event, "created", suggest_options=True),
]
@extend_schema(
methods=["GET"],
responses={200: EventTopPerUserSerializer(many=True)},

View File

@ -11,7 +11,7 @@ from authentik.events.models import NotificationRule
class NotificationRuleSerializer(ModelSerializer):
"""NotificationRule Serializer"""
destination_group_obj = GroupSerializer(read_only=True, source="destination_group")
group_obj = GroupSerializer(read_only=True, source="group")
class Meta:
model = NotificationRule
@ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer):
"name",
"transports",
"severity",
"destination_group",
"destination_group_obj",
"destination_event_user",
"group",
"group_obj",
]
@ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet):
queryset = NotificationRule.objects.all()
serializer_class = NotificationRuleSerializer
filterset_fields = ["name", "severity", "destination_group__name"]
filterset_fields = ["name", "severity", "group__name"]
ordering = ["name"]
search_fields = ["name", "destination_group__name"]
search_fields = ["name", "group__name"]

View File

@ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor):
self.reader: Reader | None = None
self._last_mtime: float = 0.0
self.logger = get_logger()
self.load()
self.open()
def path(self) -> str | None:
"""Get the path to the MMDB file to load"""
raise NotImplementedError
def load(self):
def open(self):
"""Get GeoIP Reader, if configured, otherwise none"""
path = self.path()
if path == "" or not path:
@ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor):
diff = self._last_mtime < mtime
if diff > 0:
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
self.load()
self.open()
except OSError as exc:
self.logger.warning("Failed to check MMDB age", exc=exc)

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models
from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
from authentik.stages.authenticator_static.models import StaticToken
@ -173,7 +173,7 @@ class AuditMiddleware:
message=exception_to_string(exception),
)
thread.run()
elif not should_ignore_exception(exception):
elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
thread = EventNewThread(
EventAction.SYSTEM_EXCEPTION,
request,

View File

@ -1,26 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-16 23:21
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
]
operations = [
migrations.RenameField(
model_name="notificationrule",
old_name="group",
new_name="destination_group",
),
migrations.AddField(
model_name="notificationrule",
name="destination_event_user",
field=models.BooleanField(
default=False,
help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.",
),
),
]

View File

@ -1,12 +1,10 @@
"""authentik events models"""
from collections.abc import Generator
from datetime import timedelta
from difflib import get_close_matches
from functools import lru_cache
from inspect import currentframe
from smtplib import SMTPException
from typing import Any
from uuid import uuid4
from django.apps import apps
@ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))
if hasattr(request, "user"):
self.user = get_user(request.user)
original_user = None
if hasattr(request, "session"):
original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None)
self.user = get_user(request.user, original_user)
if user:
self.user = get_user(user)
# Check if we're currently impersonating, and add that user
if hasattr(request, "session"):
from authentik.flows.views.executor import SESSION_KEY_PLAN
# Check if we're currently impersonating, and add that user
if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
# Special case for events that happen during a flow, the user might not be authenticated
# yet but is a pending user instead
if SESSION_KEY_PLAN in request.session:
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None)
# Only save `authenticated_as` if there's a different pending user in the flow
# than the user that is authenticated
if pending_user and (
(pending_user.pk and pending_user.pk != self.user.get("pk"))
or (not pending_user.pk)
):
orig_user = self.user.copy()
self.user = {"authenticated_as": orig_user, **get_user(pending_user)}
# User 255.255.255.255 as fallback if IP cannot be determined
self.client_ip = ClientIPMiddleware.get_client_ip(request)
# Enrich event data
@ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
default=NotificationSeverity.NOTICE,
help_text=_("Controls which severity level the created notifications will have."),
)
destination_group = models.ForeignKey(
group = models.ForeignKey(
Group,
help_text=_(
"Define which group of users this notification should be sent and shown to. "
@ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
blank=True,
on_delete=models.SET_NULL,
)
destination_event_user = models.BooleanField(
default=False,
help_text=_(
"When enabled, notification will be sent to user the user that triggered the event."
"When destination_group is configured, notification is sent to both."
),
)
def destination_users(self, event: Event) -> Generator[User, Any]:
if self.destination_event_user and event.user.get("pk"):
yield User(pk=event.user.get("pk"))
if self.destination_group:
yield from self.destination_group.users.all()
@property
def serializer(self) -> type[Serializer]:

View File

@ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if not result.passing:
return
if not trigger.group:
LOGGER.debug("e(trigger): trigger has no group", trigger=trigger)
return
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
# Create the notification objects
for transport in trigger.transports.all():
for user in trigger.destination_users(event):
for user in trigger.group.users.all():
LOGGER.debug("created notification")
notification_transport.apply_async(
args=[

View File

@ -2,9 +2,7 @@
from django.test import TestCase
from authentik.events.context_processors.base import get_context_processors
from authentik.events.context_processors.geoip import GeoIPContextProcessor
from authentik.events.models import Event, EventAction
class TestGeoIP(TestCase):
@ -15,7 +13,8 @@ class TestGeoIP(TestCase):
def test_simple(self):
"""Test simple city wrapper"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
# IPs from
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
self.assertEqual(
self.reader.city_dict("2.125.160.216"),
{
@ -26,12 +25,3 @@ class TestGeoIP(TestCase):
"long": -1.25,
},
)
def test_special_chars(self):
"""Test city name with special characters"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
event = Event.new(EventAction.LOGIN)
event.client_ip = "89.160.20.112"
for processor in get_context_processors():
processor.enrich_event(event)
event.save()

View File

@ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter
from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.core.models import Group
from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
from authentik.flows.views.executor import QS_QUERY
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
@ -118,92 +116,3 @@ class TestEvents(TestCase):
"pk": brand.pk.hex,
},
)
def test_from_http_flow_pending_user(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = user
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_anon(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_fake(self):
"""Test request from flow request with a pending user"""
user = User(
username=generate_id(),
email=generate_id(),
)
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)

View File

@ -6,7 +6,6 @@ from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import (
Event,
EventAction,
@ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_empty(self):
"""Test trigger without any policies attached"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
@ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_single(self):
"""Test simple transport triggering"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase):
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(execute_mock.call_count, 1)
def test_trigger_event_user(self):
"""Test trigger with event user"""
user = create_test_user()
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX
)
PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
execute_mock = MagicMock()
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save()
self.assertEqual(execute_mock.call_count, 1)
notification: Notification = execute_mock.call_args[0][0]
self.assertEqual(notification.user, user)
def test_trigger_no_group(self):
"""Test trigger without group"""
trigger = NotificationRule.objects.create(name=generate_id())
@ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase):
"""Test Policy error which would cause recursion"""
transport = NotificationTransport.objects.create(name=generate_id())
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase):
transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase):
name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX

View File

@ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]:
}
def get_user(user: User | AnonymousUser) -> dict[str, Any]:
"""Convert user object to dictionary"""
def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]:
"""Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser):
try:
user = get_anonymous_user()
@ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]:
}
if user.username == settings.ANONYMOUS_USER_NAME:
user_data["is_anonymous"] = True
if original_user:
original_data = get_user(original_user)
original_data["on_behalf_of"] = user_data
return original_data
return user_data

View File

@ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse
from django.test import override_settings
from django.test.client import RequestFactory
from django.urls import reverse
from rest_framework.exceptions import ParseError
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_flow, create_test_user
@ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase):
self.assertStageResponse(response, flow, component="ak-stage-identification")
response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-access-denied")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_json(self):
"""Test invalid JSON body"""
flow = create_test_flow()
FlowStageBinding.objects.create(
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
)
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
with override_settings(TEST=False, DEBUG=False):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)
with self.assertRaises(ParseError):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)

View File

@ -55,7 +55,7 @@ from authentik.flows.planner import (
FlowPlanner,
)
from authentik.flows.stage import AccessDeniedStage, StageView
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
@ -234,13 +234,12 @@ class FlowExecutorView(APIView):
"""Handle exception in stage execution"""
if settings.DEBUG or settings.TEST:
raise exc
capture_exception(exc)
self._logger.warning(exc)
if not should_ignore_exception(exc):
capture_exception(exc)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
challenge = FlowErrorChallenge(self.request, exc)
challenge.is_valid(raise_exception=True)
return to_stage_response(self.request, HttpChallengeResponse(challenge))

View File

@ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted
from docker.errors import DockerException
from h11 import LocalProtocolError
from ldap3.core.exceptions import LDAPException
from psycopg.errors import Error
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException
@ -45,49 +44,6 @@ class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry."""
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
class SentryTransport(HttpTransport):
"""Custom sentry transport with custom user-agent"""
@ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float:
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def should_ignore_exception(exc: Exception) -> bool:
"""Check if an exception should be dropped"""
return isinstance(exc, ignored_classes)
def before_send(event: dict, hint: dict) -> dict | None:
"""Check if error is database error, and ignore if so"""
from psycopg.errors import Error
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
exc_value = None
if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"]
if should_ignore_exception(exc_value):
if isinstance(exc_value, ignored_classes):
LOGGER.debug("dropping exception", exc=exc_value)
return None
if "logger" in event:

View File

@ -2,7 +2,7 @@
from django.test import TestCase
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException, before_send
class TestSentry(TestCase):
@ -10,8 +10,8 @@ class TestSentry(TestCase):
def test_error_not_sent(self):
"""Test SentryIgnoredError not sent"""
self.assertTrue(should_ignore_exception(SentryIgnoredException()))
self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
def test_error_sent(self):
"""Test error sent"""
self.assertFalse(should_ignore_exception(ValueError()))
self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)}))

View File

@ -1,13 +1,15 @@
"""authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import AuthenticatedSession, Provider
from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
@ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""

View File

@ -1,9 +1,11 @@
"""Websocket tests"""
from dataclasses import asdict
from unittest.mock import patch
from channels.routing import URLRouter
from channels.testing import WebsocketCommunicator
from django.contrib.contenttypes.models import ContentType
from django.test import TransactionTestCase
from authentik import __version__
@ -14,6 +16,12 @@ from authentik.providers.proxy.models import ProxyProvider
from authentik.root import websocket
def patched__get_ct_cached(app_label, codename):
"""Caches `ContentType` instances like its `QuerySet` does."""
return ContentType.objects.get(app_label=app_label, permission__codename=codename)
@patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached)
class TestOutpostWS(TransactionTestCase):
"""Websocket tests"""

View File

@ -1,10 +1,23 @@
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
@receiver(user_logged_out)
def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
"""Revoke tokens upon user logout"""
if not request.session or not request.session.session_key:
return
AccessToken.objects.filter(
user=user,
session__session__session_key=request.session.session_key,
).delete()
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
"""Revoke tokens upon user logout"""

View File

@ -387,7 +387,8 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual(
response.url,
(
f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}"
f"http://localhost#access_token={token.token}"
f"&id_token={provider.encode(token.id_token.to_dict())}"
f"&token_type={TOKEN_TYPE}"
f"&expires_in={int(expires)}&state={state}"
),
@ -562,6 +563,7 @@ class TestAuthorize(OAuthTestCase):
"url": "http://localhost",
"title": f"Redirecting to {app.name}...",
"attrs": {
"access_token": token.token,
"id_token": provider.encode(token.id_token.to_dict()),
"token_type": TOKEN_TYPE,
"expires_in": "3600",

View File

@ -150,12 +150,12 @@ class OAuthAuthorizationParams:
self.check_redirect_uri()
self.check_grant()
self.check_scope(github_compat)
self.check_nonce()
self.check_code_challenge()
if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
self.check_nonce()
self.check_code_challenge()
def check_grant(self):
"""Check grant"""
@ -630,6 +630,7 @@ class OAuthFulfillmentStage(StageView):
if self.params.response_type in [
ResponseTypes.ID_TOKEN_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
ResponseTypes.ID_TOKEN,
ResponseTypes.CODE_TOKEN,
]:
query_fragment["access_token"] = token.token

View File

@ -2,11 +2,13 @@
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION,
@ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import (
from authentik.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
if not request.session or not request.session.session_key:
return
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer()

View File

@ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -20,9 +20,6 @@ from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine
from authentik.policies.views import PolicyAccessView
from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
PLAN_CONNECTION_SETTINGS = "connection_settings"
class RACStartView(PolicyAccessView):
@ -112,15 +109,10 @@ class RACFinalStage(RedirectStage):
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
settings = self.executor.plan.context.get(PLAN_CONNECTION_SETTINGS)
if not settings:
settings = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(
PLAN_CONNECTION_SETTINGS
)
token = ConnectionToken.objects.create(
provider=self.provider,
endpoint=self.endpoint,
settings=settings or {},
settings=self.executor.plan.context.get("connection_settings", {}),
session=self.request.session["authenticatedsession"],
expires=now() + timedelta_from_string(self.provider.connection_expiry),
expiring=True,

View File

@ -5,6 +5,7 @@ from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import (
SCIM_GROUP_SCHEMA,
PatchOp,
PatchOperation,
PatchRequest,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,

View File

@ -1,7 +1,5 @@
"""Custom SCIM schemas"""
from enum import Enum
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
@ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
)
class PatchOp(str, Enum):
replace = "replace"
remove = "remove"
add = "add"
@classmethod
def _missing_(cls, value):
value = value.lower()
for member in cls:
if member.lower() == value:
return member
return None
class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas"""
@ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest):
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
op: PatchOp
path: str | None

View File

@ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -38,7 +38,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -74,7 +73,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
import django
from channels.routing import ProtocolTypeRouter, URLRouter
from defusedxml import defuse_stdlib
from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from authentik.root.setup import setup
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
setup()
defuse_stdlib()
django.setup()

View File

@ -27,7 +27,7 @@ from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik import get_full_version
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
# set the default Django settings module for the 'celery' program.
@ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
CTX_TASK_ID.set(...)
if not should_ignore_exception(exception):
if before_send({}, {"exc_info": (None, exception, None)}) is not None:
Event.new(
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
).save()

View File

@ -1,49 +1,13 @@
"""authentik database backend"""
from django.core.checks import Warning
from django.db.backends.base.validation import BaseDatabaseValidation
from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
return self._check_encoding()
def _check_encoding(self):
"""Throw a warning when the server_encoding is not UTF-8 or
server_encoding and client_encoding are mismatched"""
messages = []
with self.connection.cursor() as cursor:
cursor.execute("SHOW server_encoding;")
server_encoding = cursor.fetchone()[0]
cursor.execute("SHOW client_encoding;")
client_encoding = cursor.fetchone()[0]
if server_encoding != client_encoding:
messages.append(
Warning(
"PostgreSQL Server and Client encoding are mismatched: Server: "
f"{server_encoding}, Client: {client_encoding}",
id="ak.db.W001",
)
)
if server_encoding != "UTF8":
messages.append(
Warning(
f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
id="ak.db.W002",
)
)
return messages
class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials"""
validation_class = DatabaseValidation
def get_connection_params(self):
"""Refresh DB credentials before getting connection params"""
conn_params = super().get_connection_params()

View File

@ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [
"MIDDLEWARE",
"AUTHENTICATION_BACKENDS",
"CELERY",
"SPECTACULAR_SETTINGS",
"REST_FRAMEWORK",
]
SILENCED_SYSTEM_CHECKS = [
@ -470,8 +468,6 @@ def _update_settings(app_path: str):
TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:

View File

@ -1,26 +0,0 @@
import os
import warnings
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from authentik.lib.config import CONFIG
def setup():
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
"ignore",
"defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
"ignore",
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")

View File

@ -3,40 +3,21 @@
import os
from argparse import ArgumentParser
from unittest import TestCase
from unittest.mock import patch
import pytest
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.test.runner import DiscoverRunner
from structlog.stdlib import get_logger
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.config import CONFIG
from authentik.lib.sentry import sentry_init
from authentik.root.signals import post_startup, pre_startup, startup
from tests.e2e.utils import get_docker_tag
# globally set maxDiff to none to show full assert error
TestCase.maxDiff = None
def get_docker_tag() -> str:
"""Get docker-tag based off of CI variables"""
env_pr_branch = "GITHUB_HEAD_REF"
default_branch = "GITHUB_REF"
branch_name = os.environ.get(default_branch, "main")
if os.environ.get(env_pr_branch, "") != "":
branch_name = os.environ[env_pr_branch]
branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
return f"gh-{branch_name}"
def patched__get_ct_cached(app_label, codename):
"""Caches `ContentType` instances like its `QuerySet` does."""
return ContentType.objects.get(app_label=app_label, permission__codename=codename)
class PytestTestRunner(DiscoverRunner): # pragma: no cover
"""Runs pytest to discover and run tests."""
@ -78,9 +59,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
for key, value in test_config.items():
CONFIG.set(key, value)
ASN_CONTEXT_PROCESSOR.load()
GEOIP_CONTEXT_PROCESSOR.load()
sentry_init()
self.logger.debug("Test environment configured")
@ -171,9 +149,8 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
return 1
self.logger.info("Running tests", test_files=self.args)
with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
try:
return pytest.main(self.args)
except Exception as e:
self.logger.error("Error running tests", error=str(e), test_files=self.args)
return 1
try:
return pytest.main(self.args)
except Exception as e:
self.logger.error("Error running tests", error=str(e), test_files=self.args)
return 1

View File

@ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str):
return
# Delete all sync tasks from the cache
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
# The order of these operations needs to be preserved as each depends on the previous one(s)
# 1. User and group sync can happen simultaneously
# 2. Membership sync needs to run afterwards
# 3. Finally, user and group deletions can happen simultaneously
user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator(
source, GroupLDAPSynchronizer
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
# Finally, deletions. What we'd really like to do here is something like
# ```
# user_identifiers = <ldap query>
# User.objects.exclude(
# usersourceconnection__identifier__in=user_uniqueness_identifiers,
# ).delete()
# ```
# This runs into performance issues in large installations. So instead we spread the
# work out into three steps:
# 1. Get every object from the LDAP source.
# 2. Mark every object as "safe" in the database. This is quick, but any error could
# mean deleting users which should not be deleted, so we do it immediately, in
# large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks.
group(
ldap_sync_paginator(source, UserLDAPForwardDeletion)
+ ldap_sync_paginator(source, GroupLDAPForwardDeletion),
),
)
membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer)
user_group_deletion = ldap_sync_paginator(
source, UserLDAPForwardDeletion
) + ldap_sync_paginator(source, GroupLDAPForwardDeletion)
# Celery is buggy with empty groups, so we are careful only to add non-empty groups.
# See https://github.com/celery/celery/issues/9772
task_groups = []
if user_group_sync:
task_groups.append(group(user_group_sync))
if membership_sync:
task_groups.append(group(membership_sync))
if user_group_deletion:
task_groups.append(group(user_group_deletion))
all_tasks = chain(task_groups)
all_tasks()
task()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:

View File

@ -1,277 +0,0 @@
"""Test SCIM Group"""
from json import dumps
from uuid import uuid4
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.sources.scim.models import (
SCIMSource,
SCIMSourceGroup,
)
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
class TestSCIMGroups(APITestCase):
"""Test SCIM Group view"""
def setUp(self) -> None:
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
def test_group_list(self):
"""Test full group list"""
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_group_list_single(self):
"""Test full group list (single group)"""
group = Group.objects.create(name=generate_id())
user = create_test_user()
group.users.add(user)
SCIMSourceGroup.objects.create(
source=self.source,
group=group,
id=str(uuid4()),
)
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(group.pk),
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
SCIMGroupSchema.model_validate_json(response.content, strict=True)
def test_group_create(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members(self):
"""Test group create"""
user = create_test_user()
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{
"displayName": generate_id(),
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members_empty(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_duplicate(self):
"""Test group create (duplicate)"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 409)
self.assertJSONEqual(
response.content,
{
"detail": "Group with ID exists already.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"scimType": "uniqueness",
"status": 409,
},
)
def test_group_update(self):
"""Test group update"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
def test_group_update_non_existent(self):
"""Test group update"""
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(uuid4()),
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=404)
self.assertJSONEqual(
response.content,
{
"detail": "Group not found.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"status": 404,
},
)
def test_group_patch_add(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "Add",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertTrue(group.users.filter(pk=user.pk).exists())
def test_group_patch_remove(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(user)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "remove",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertFalse(group.users.filter(pk=user.pk).exists())
def test_group_delete(self):
"""Test group delete"""
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=204)

View File

@ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase):
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
"0123456789",
)
def test_user_update(self):
"""Test user update"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
data=dumps(
{
"id": str(existing.pk),
"userName": generate_id(),
"externalId": ext_id,
"emails": [
{
"primary": True,
"value": user.email,
}
],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_user_delete(self):
"""Test user delete"""
user = create_test_user()
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 204)

View File

@ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_
from rest_framework.request import Request
from rest_framework.views import APIView
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import Token, TokenIntents, User
from authentik.sources.scim.models import SCIMSource
@ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication):
_username, _, password = b64decode(key.encode()).decode().partition(":")
token = self.check_token(password, source_slug)
if token:
CTX_AUTH_VIA.set("scim_basic")
return (token.user, token)
return None
@ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication):
token = self.check_token(key, source_slug)
if not token:
return None
CTX_AUTH_VIA.set("scim_token")
return (token.user, token)

View File

@ -1,11 +1,13 @@
"""SCIM Utils"""
from typing import Any
from urllib.parse import urlparse
from django.conf import settings
from django.core.paginator import Page, Paginator
from django.db.models import Q, QuerySet
from django.http import HttpRequest
from django.urls import resolve
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer
@ -44,7 +46,7 @@ class SCIMView(APIView):
logger: BoundLogger
permission_classes = [IsAuthenticated]
parser_classes = [SCIMParser, JSONParser]
parser_classes = [SCIMParser]
renderer_classes = [SCIMRenderer]
def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
@ -54,6 +56,28 @@ class SCIMView(APIView):
def get_authenticators(self):
return [SCIMTokenAuth(self)]
def patch_resolve_value(self, raw_value: dict) -> User | Group | None:
"""Attempt to resolve a raw `value` attribute of a patch operation into
a database model"""
model = User
query = {}
if "$ref" in raw_value:
url = urlparse(raw_value["$ref"])
if match := resolve(url.path):
if match.url_name == "v2-users":
model = User
query = {"pk": int(match.kwargs["user_id"])}
elif "type" in raw_value:
match raw_value["type"]:
case "User":
model = User
query = {"pk": int(raw_value["value"])}
case "Group":
model = Group
else:
return None
return model.objects.filter(**query).first()
def filter_parse(self, request: Request):
"""Parse the path of a Patch Operation"""
path = request.query_params.get("filter")

View File

@ -1,58 +0,0 @@
from enum import Enum
from pydanticscim.responses import SCIMError as BaseSCIMError
from rest_framework.exceptions import ValidationError
class SCIMErrorTypes(Enum):
invalid_filter = "invalidFilter"
too_many = "tooMany"
uniqueness = "uniqueness"
mutability = "mutability"
invalid_syntax = "invalidSyntax"
invalid_path = "invalidPath"
no_target = "noTarget"
invalid_value = "invalidValue"
invalid_vers = "invalidVers"
sensitive = "sensitive"
class SCIMError(BaseSCIMError):
scimType: SCIMErrorTypes | None = None
detail: str | None = None
class SCIMValidationError(ValidationError):
status_code = 400
default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400)
def __init__(self, detail: SCIMError | None):
if detail is None:
detail = self.default_detail
detail.status = self.status_code
self.detail = detail.model_dump(mode="json", exclude_none=True)
class SCIMConflictError(SCIMValidationError):
status_code = 409
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
scimType=SCIMErrorTypes.uniqueness,
status=self.status_code,
)
)
class SCIMNotFoundError(SCIMValidationError):
status_code = 404
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
status=self.status_code,
)
)

View File

@ -4,25 +4,19 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydantic import ValidationError as PydanticValidationError
from pydanticscim.group import GroupMember
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from scim2_filter_parser.attr_paths import AttrPath
from authentik.core.models import Group, User
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
from authentik.sources.scim.models import SCIMSourceGroup
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import (
SCIMConflictError,
SCIMNotFoundError,
SCIMValidationError,
)
class GroupsView(SCIMObjectView):
@ -33,7 +27,7 @@ class GroupsView(SCIMObjectView):
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
"""Convert Group to SCIM data"""
payload = SCIMGroupModel(
schemas=[SCIM_GROUP_SCHEMA],
schemas=[SCIM_USER_SCHEMA],
id=str(scim_group.group.pk),
externalId=scim_group.id,
displayName=scim_group.group.name,
@ -64,7 +58,7 @@ class GroupsView(SCIMObjectView):
if group_id:
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
return Response(self.group_to_scim(connection))
connections = (
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
@ -125,7 +119,7 @@ class GroupsView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing group")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_group(None, request.data)
return Response(self.group_to_scim(connection), status=201)
@ -135,44 +129,10 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection = self.update_group(connection, request.data)
return Response(self.group_to_scim(connection), status=200)
@atomic
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
"""Patch group handler"""
connection = SCIMSourceGroup.objects.filter(
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
for _op in request.data.get("Operations", []):
operation = PatchOperation.model_validate(_op)
if operation.op.lower() not in ["add", "remove", "replace"]:
raise SCIMValidationError()
attr_path = AttrPath(f'{operation.path} eq ""', {})
if attr_path.first_path == ("members", None, None):
# FIXME: this can probably be de-duplicated
if operation.op == PatchOp.add:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.add(*User.objects.filter(query))
elif operation.op == PatchOp.remove:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.remove(*User.objects.filter(query))
return Response(self.group_to_scim(connection), status=200)
@atomic
def delete(self, request: Request, group_id: str, **kwargs) -> Response:
"""Delete group handler"""
@ -180,7 +140,7 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection.group.delete()
connection.delete()
return Response(status=204)

View File

@ -1,11 +1,11 @@
"""SCIM Meta views"""
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
class ResourceTypesView(SCIMView):
@ -138,7 +138,7 @@ class ResourceTypesView(SCIMView):
resource = [x for x in resource_types if x.get("id") == resource_type]
if resource:
return Response(resource[0])
raise SCIMNotFoundError("Resource not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -3,12 +3,12 @@
from json import loads
from django.conf import settings
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
with open(
settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
@ -44,7 +44,7 @@ class SchemaView(SCIMView):
schema = [x for x in schemas if x.get("id") == schema_uri]
if schema:
return Response(schema[0])
raise SCIMNotFoundError("Schema not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView):
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
"authenticationSchemes": auth_schemas,
# We only support patch for groups currently, so don't broadly advertise it.
# Implementations that require Group patch will use it regardless of this flag.
"patch": {"supported": False},
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
"filter": {

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydanticscim.user import Email, EmailKind, Name
from rest_framework.exceptions import ValidationError
@ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import User as SCIMUserModel
from authentik.sources.scim.models import SCIMSourceUser
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
class UsersView(SCIMObjectView):
@ -70,7 +69,7 @@ class UsersView(SCIMObjectView):
.first()
)
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
return Response(self.user_to_scim(connection))
connections = (
SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
@ -123,7 +122,7 @@ class UsersView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing user")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_user(None, request.data)
return Response(self.user_to_scim(connection), status=201)
@ -131,7 +130,7 @@ class UsersView(SCIMObjectView):
"""Update user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
self.update_user(connection, request.data)
return Response(self.user_to_scim(connection), status=200)
@ -140,7 +139,7 @@ class UsersView(SCIMObjectView):
"""Delete user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
connection.user.delete()
connection.delete()
return Response(status=204)

View File

@ -1,7 +1,6 @@
"""Validation stage challenge checking"""
from json import loads
from typing import TYPE_CHECKING
from urllib.parse import urlencode
from django.http import HttpRequest
@ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice
from authentik.stages.authenticator_sms.models import SMSDevice
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView
class DeviceChallenge(PassiveSerializer):
@ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer):
def get_challenge_for_device(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
request: HttpRequest, stage: AuthenticatorValidateStage, device: Device
) -> dict:
"""Generate challenge for a single device"""
if isinstance(device, WebAuthnDevice):
return get_webauthn_challenge(stage_view, stage, device)
return get_webauthn_challenge(request, stage, device)
if isinstance(device, EmailDevice):
return {"email": mask_email(device.email)}
# Code-based challenges have no hints
@ -67,30 +64,26 @@ def get_challenge_for_device(
def get_webauthn_challenge_without_user(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
request: HttpRequest, stage: AuthenticatorValidateStage
) -> dict:
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check
who the device belongs to."""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=[],
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
def get_webauthn_challenge(
stage_view: "AuthenticatorValidateStageView",
stage: AuthenticatorValidateStage,
device: WebAuthnDevice | None = None,
request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None
) -> dict:
"""Send the client a challenge that we'll check later"""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
allowed_credentials = []
@ -101,14 +94,12 @@ def get_webauthn_challenge(
allowed_credentials.append(user_device.descriptor)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=allowed_credentials,
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
@ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
"""Validate WebAuthn Challenge"""
request = stage_view.request
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
try:
credential = parse_authentication_credential_json(data)

View File

@ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={
"device_class": device_class,
"device_uid": device.pk,
"challenge": get_challenge_for_device(self, stage, device),
"challenge": get_challenge_for_device(self.request, stage, device),
"last_used": device.last_used,
}
)
@ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_class": DeviceClasses.WEBAUTHN,
"device_uid": -1,
"challenge": get_webauthn_challenge_without_user(
self,
self.request,
self.executor.current_stage,
),
"last_used": None,

View File

@ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice,
WebAuthnDeviceType,
)
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.user_login.models import UserLoginStage
@ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
device_classes=[DeviceClasses.WEBAUTHN],
webauthn_user_verification=UserVerification.PREFERRED,
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
del challenge["challenge"]
self.assertEqual(
challenge,
@ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
with self.assertRaises(ValidationError):
validate_challenge_webauthn(
{},
StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request),
self.user,
{}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user
)
def test_device_challenge_webauthn_restricted(self):
@ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge["allowCredentials"],
[
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
)
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(
challenge["rpId"],
"testserver",
)
self.assertEqual(
challenge["timeout"],
60000,
)
self.assertEqual(
challenge["userVerification"],
"preferred",
challenge,
{
"allowCredentials": [
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
def test_get_challenge_userless(self):
@ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
challenge = get_webauthn_challenge_without_user(request, stage)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge,
{
"allowCredentials": [],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
challenge = get_webauthn_challenge_without_user(stage_view, stage)
self.assertEqual(challenge["allowCredentials"], [])
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(challenge["rpId"], "testserver")
self.assertEqual(challenge["timeout"], 60000)
self.assertEqual(challenge["userVerification"], "preferred")
def test_validate_challenge_unrestricted(self):
"""Test webauthn authentication (unrestricted webauthn device)"""
@ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.WEBAUTHN],
)
plan = FlowPlan(flow.pk.hex)
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request = get_request("/")
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
request.session.save()
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request.META["SERVER_NAME"] = "localhost"
request.META["SERVER_PORT"] = "9000"

View File

@ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer):
"resident_key_requirement",
"device_type_restrictions",
"device_type_restrictions_obj",
"max_attempts",
]

File diff suppressed because one or more lines are too long

View File

@ -1,21 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-13 22:41
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"authentik_stages_authenticator_webauthn",
"0012_webauthndevice_created_webauthndevice_last_updated_and_more",
),
]
operations = [
migrations.AddField(
model_name="authenticatorwebauthnstage",
name="max_attempts",
field=models.PositiveIntegerField(default=0),
),
]

View File

@ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage):
device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True)
max_attempts = models.PositiveIntegerField(default=0)
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.stages.authenticator_webauthn.api.stages import (

View File

@ -5,13 +5,12 @@ from uuid import UUID
from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict
from django.utils.translation import gettext as __
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from webauthn import options_to_json
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from webauthn.helpers.exceptions import WebAuthnException
from webauthn.helpers.exceptions import InvalidRegistrationResponse
from webauthn.helpers.structs import (
AttestationConveyancePreference,
AuthenticatorAttachment,
@ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import (
)
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge"
PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt"
SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge"
class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge):
@ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
def validate_response(self, response: dict) -> dict:
"""Validate webauthn challenge response"""
challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]
challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
try:
registration: VerifiedRegistration = verify_registration_response(
@ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
expected_rp_id=get_rp_id(self.request),
expected_origin=get_origin(self.request),
)
except WebAuthnException as exc:
except InvalidRegistrationResponse as exc:
self.stage.logger.warning("registration failed", exc=exc)
raise ValidationError(f"Registration failed. Error: {exc}") from None
@ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response_class = AuthenticatorWebAuthnChallengeResponse
def get_challenge(self, *args, **kwargs) -> Challenge:
# clear session variables prior to starting a new registration
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
# clear flow variables prior to starting a new registration
self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
user = self.get_pending_user()
# library accepts none so we store null in the database, but if there is a value
@ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
attestation=AttestationConveyancePreference.DIRECT,
)
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge
self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge
self.request.session.save()
return AuthenticatorWebAuthnChallenge(
data={
"registration": loads(options_to_json(registration_options)),
@ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response.user = self.get_pending_user()
return response
def challenge_invalid(self, response):
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1
if (
stage.max_attempts > 0
and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts
):
return self.executor.stage_invalid(
__(
"Exceeded maximum attempts. "
"Contact your {brand} administrator for help.".format(
brand=self.request.brand.branding_title
)
)
)
return super().challenge_invalid(response)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# Webauthn Challenge has already been validated
webauthn_credential: VerifiedRegistration = response.validated_data["response"]
@ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
else:
return self.executor.stage_invalid("Device with Credential ID already exists.")
return self.executor.stage_ok()
def cleanup(self):
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)

View File

@ -18,7 +18,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice,
WebAuthnDeviceType,
)
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
@ -57,9 +57,6 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
)
plan: FlowPlan = self.client.session[SESSION_KEY_PLAN]
self.assertEqual(response.status_code, 200)
session = self.client.session
self.assertStageResponse(
@ -73,7 +70,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"name": self.user.username,
"displayName": self.user.name,
},
"challenge": bytes_to_base64url(plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]),
"challenge": bytes_to_base64url(session[SESSION_KEY_WEBAUTHN_CHALLENGE]),
"pubKeyCredParams": [
{"type": "public-key", "alg": -7},
{"type": "public-key", "alg": -8},
@ -100,11 +97,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"""Test registration"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -149,11 +146,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -212,11 +209,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -262,11 +259,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -301,109 +298,3 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists())
def test_register_max_retries(self):
"""Test registration (exceeding max retries)"""
self.stage.max_attempts = 2
self.stage.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
# first failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-authenticator-webauthn",
response_errors={
"response": [
{
"string": (
"Registration failed. Error: Unable to decode "
"client_data_json bytes as JSON"
),
"code": "invalid",
}
]
},
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())
# Second failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-access-denied",
error_message=(
"Exceeded maximum attempts. Contact your authentik administrator for help."
),
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())

View File

@ -101,9 +101,9 @@ class BoundSessionMiddleware(SessionMiddleware):
SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING
)
if configured_binding_net != NetworkBinding.NO_BINDING:
BoundSessionMiddleware.recheck_session_net(configured_binding_net, last_ip, new_ip)
self.recheck_session_net(configured_binding_net, last_ip, new_ip)
if configured_binding_geo != GeoIPBinding.NO_BINDING:
BoundSessionMiddleware.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
self.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
# If we got to this point without any error being raised, we need to
# update the last saved IP to the current one
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
@ -111,8 +111,7 @@ class BoundSessionMiddleware(SessionMiddleware):
# (== basically requires the user to be logged in)
request.session[request.session.model.Keys.LAST_IP] = new_ip
@staticmethod
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
def recheck_session_net(self, binding: NetworkBinding, last_ip: str, new_ip: str):
"""Check network/ASN binding"""
last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip)
new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip)
@ -159,8 +158,7 @@ class BoundSessionMiddleware(SessionMiddleware):
new_ip,
)
@staticmethod
def recheck_session_geo(binding: GeoIPBinding, last_ip: str, new_ip: str):
def recheck_session_geo(self, binding: GeoIPBinding, last_ip: str, new_ip: str):
"""Check GeoIP binding"""
last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip)
new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip)
@ -181,8 +179,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.continent != new_geo.continent:
raise SessionBindingBroken(
"geoip.continent",
last_geo.continent.to_dict(),
new_geo.continent.to_dict(),
last_geo.continent,
new_geo.continent,
last_ip,
new_ip,
)
@ -194,8 +192,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.country != new_geo.country:
raise SessionBindingBroken(
"geoip.country",
last_geo.country.to_dict(),
new_geo.country.to_dict(),
last_geo.country,
new_geo.country,
last_ip,
new_ip,
)
@ -204,8 +202,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.city != new_geo.city:
raise SessionBindingBroken(
"geoip.city",
last_geo.city.to_dict(),
new_geo.city.to_dict(),
last_geo.city,
new_geo.city,
last_ip,
new_ip,
)

View File

@ -3,7 +3,6 @@
from time import sleep
from unittest.mock import patch
from django.http import HttpRequest
from django.urls import reverse
from django.utils.timezone import now
@ -18,12 +17,7 @@ from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.user_login.middleware import (
BoundSessionMiddleware,
SessionBindingBroken,
logout_extra,
)
from authentik.stages.user_login.models import GeoIPBinding, NetworkBinding, UserLoginStage
from authentik.stages.user_login.models import UserLoginStage
class TestUserLoginStage(FlowTestCase):
@ -198,52 +192,3 @@ class TestUserLoginStage(FlowTestCase):
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
response = self.client.get(reverse("authentik_api:application-list"))
self.assertEqual(response.status_code, 403)
def test_binding_net_break_log(self):
"""Test logout_extra with exception"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json
for args, expect in [
[[NetworkBinding.BIND_ASN, "8.8.8.8", "8.8.8.8"], ["network.missing"]],
[[NetworkBinding.BIND_ASN, "1.0.0.1", "1.128.0.1"], ["network.asn"]],
[
[NetworkBinding.BIND_ASN_NETWORK, "12.81.96.1", "12.81.128.1"],
["network.asn_network"],
],
[[NetworkBinding.BIND_ASN_NETWORK_IP, "1.0.0.1", "1.0.0.2"], ["network.ip"]],
]:
with self.subTest(args[0]):
with self.assertRaises(SessionBindingBroken) as cm:
BoundSessionMiddleware.recheck_session_net(*args)
self.assertEqual(cm.exception.reason, expect[0])
# Ensure the request can be logged without throwing errors
self.client.force_login(self.user)
request = HttpRequest()
request.session = self.client.session
request.user = self.user
logout_extra(request, cm.exception)
def test_binding_geo_break_log(self):
"""Test logout_extra with exception"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
for args, expect in [
[[GeoIPBinding.BIND_CONTINENT, "8.8.8.8", "8.8.8.8"], ["geoip.missing"]],
[[GeoIPBinding.BIND_CONTINENT, "2.125.160.216", "67.43.156.1"], ["geoip.continent"]],
[
[GeoIPBinding.BIND_CONTINENT_COUNTRY, "81.2.69.142", "89.160.20.112"],
["geoip.country"],
],
[
[GeoIPBinding.BIND_CONTINENT_COUNTRY_CITY, "2.125.160.216", "81.2.69.142"],
["geoip.city"],
],
]:
with self.subTest(args[0]):
with self.assertRaises(SessionBindingBroken) as cm:
BoundSessionMiddleware.recheck_session_geo(*args)
self.assertEqual(cm.exception.reason, expect[0])
# Ensure the request can be logged without throwing errors
self.client.force_login(self.user)
request = HttpRequest()
request.session = self.client.session
request.user = self.user
logout_extra(request, cm.exception)

View File

@ -1,7 +1,6 @@
"""Serializer for tenants models"""
from django_tenants.utils import get_public_schema_name
from rest_framework.fields import JSONField
from rest_framework.generics import RetrieveUpdateAPIView
from rest_framework.permissions import SAFE_METHODS
@ -13,8 +12,6 @@ from authentik.tenants.models import Tenant
class SettingsSerializer(ModelSerializer):
"""Settings Serializer"""
footer_links = JSONField(required=False)
class Meta:
model = Tenant
fields = [

View File

@ -16,7 +16,6 @@ def check_embedded_outpost_disabled(app_configs, **kwargs):
"Embedded outpost must be disabled when tenants API is enabled.",
hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
"True, or disable the tenants API by setting tenants.enabled to False",
id="ak.tenants.E001",
)
]
return []

File diff suppressed because it is too large Load Diff

View File

@ -31,7 +31,7 @@ services:
volumes:
- redis:/data
server:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.2}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.1}
restart: unless-stopped
command: server
environment:
@ -55,7 +55,7 @@ services:
redis:
condition: service_healthy
worker:
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.2}
image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2025.6.1}
restart: unless-stopped
command: worker
environment:

8
go.mod
View File

@ -6,7 +6,7 @@ require (
beryju.io/ldap v0.1.0
github.com/avast/retry-go/v4 v4.6.1
github.com/coreos/go-oidc/v3 v3.14.1
github.com/getsentry/sentry-go v0.34.0
github.com/getsentry/sentry-go v0.33.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.11
github.com/go-openapi/runtime v0.28.0
@ -18,18 +18,18 @@ require (
github.com/gorilla/sessions v1.4.0
github.com/gorilla/websocket v1.5.3
github.com/grafana/pyroscope-go v1.2.2
github.com/jellydator/ttlcache/v3 v3.4.0
github.com/jellydator/ttlcache/v3 v3.3.0
github.com/mitchellh/mapstructure v1.5.0
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
github.com/pires/go-proxyproto v0.8.1
github.com/prometheus/client_golang v1.22.0
github.com/redis/go-redis/v9 v9.11.0
github.com/redis/go-redis/v9 v9.10.0
github.com/sethvargo/go-envconfig v1.3.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2025062.6
goauthentik.io/api/v3 v3.2025061.2
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0

16
go.sum
View File

@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/getsentry/sentry-go v0.33.0 h1:YWyDii0KGVov3xOaamOnF0mjOrqSjBqwv48UEzn7QFg=
github.com/getsentry/sentry-go v0.33.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
@ -203,8 +203,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY=
github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4=
github.com/jellydator/ttlcache/v3 v3.3.0 h1:BdoC9cE81qXfrxeb9eoJi9dWrdhSuwXMAnHTbnBm4Wc=
github.com/jellydator/ttlcache/v3 v3.3.0/go.mod h1:bj2/e0l4jRnQdrnSTaGTsh4GSXvMjQcy41i7th0GVGw=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@ -251,8 +251,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs=
github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
@ -298,8 +298,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
goauthentik.io/api/v3 v3.2025062.6 h1:rlChhGP2vJufYCaTMb4sbRBEE1p2uL5T4HzMqF1AJ4A=
goauthentik.io/api/v3 v3.2025062.6/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
goauthentik.io/api/v3 v3.2025061.2 h1:bKmrl82Gz6J8lz3f+QIH9g+MEkl3MvkMXF34GktesA0=
goauthentik.io/api/v3 v3.2025061.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=

View File

@ -33,4 +33,4 @@ func UserAgent() string {
return fmt.Sprintf("authentik@%s", FullVersion())
}
const VERSION = "2025.6.2"
const VERSION = "2025.6.1"

Some files were not shown because too many files have changed in this diff Show More