Compare commits

..

1 Commits

Author SHA1 Message Date
faf8bf591f providers/saml: configuration for default NameID Policy
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-18 11:39:28 +02:00
318 changed files with 7415 additions and 6536 deletions

View File

@ -21,8 +21,6 @@ optional_value = final
[bumpversion:file:package.json]
[bumpversion:file:package-lock.json]
[bumpversion:file:docker-compose.yml]
[bumpversion:file:schema.yml]
@ -33,4 +31,6 @@ optional_value = final
[bumpversion:file:internal/constants/constants.go]
[bumpversion:file:web/src/common/constants.ts]
[bumpversion:file:lifecycle/aws/template.yaml]

View File

@ -7,9 +7,6 @@ charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.toml]
indent_size = 2
[*.html]
indent_size = 2

View File

@ -15,8 +15,8 @@ jobs:
matrix:
version:
- docs
- version-2025-4
- version-2025-2
- version-2024-12
steps:
- uses: actions/checkout@v4
- run: |

View File

@ -226,61 +226,6 @@ jobs:
flags: e2e
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
test-conformance:
name: test-conformance (${{ matrix.job.name }})
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
fail-fast: false
matrix:
job:
- name: basic
glob: tests/openid_conformance/test_basic.py
- name: implicit
glob: tests/openid_conformance/test_implicit.py
steps:
- uses: actions/checkout@v4
- name: Setup authentik env
uses: ./.github/actions/setup
- name: Setup e2e env (chrome, etc)
run: |
docker compose -f tests/e2e/docker-compose.yml up -d --quiet-pull
- name: Setup conformance suite
run: |
docker compose -f tests/openid_conformance/compose.yml up -d --quiet-pull
- id: cache-web
uses: actions/cache@v4
with:
path: web/dist
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
- name: prepare web ui
if: steps.cache-web.outputs.cache-hit != 'true'
working-directory: web
run: |
npm ci
make -C .. gen-client-ts
npm run build
npm run build:sfe
- name: run conformance
run: |
uv run coverage run manage.py test ${{ matrix.job.glob }}
uv run coverage xml
- if: ${{ always() }}
uses: codecov/codecov-action@v5
with:
flags: conformance
token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: codecov/test-results-action@v1
with:
flags: conformance
file: unittest.xml
token: ${{ secrets.CODECOV_TOKEN }}
- if: ${{ !cancelled() }}
uses: actions/upload-artifact@v4
with:
name: conformance-certification-${{ matrix.job.glob }}
path: tests/openid_conformance/exports/
ci-core-mark:
if: always()
needs:
@ -290,7 +235,6 @@ jobs:
- test-unittest
- test-integration
- test-e2e
- test-conformance
runs-on: ubuntu-latest
steps:
- uses: re-actors/alls-green@release/v1

View File

@ -41,27 +41,6 @@ jobs:
- name: test
working-directory: website/
run: npm test
build:
runs-on: ubuntu-latest
name: ${{ matrix.job }}
strategy:
fail-fast: false
matrix:
job:
- build
- build:integrations
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: website/package.json
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
run: npm ci
- name: build
working-directory: website/
run: npm run ${{ matrix.job }}
build-container:
runs-on: ubuntu-latest
permissions:
@ -115,7 +94,6 @@ jobs:
needs:
- lint
- test
- build
- build-container
runs-on: ubuntu-latest
steps:

View File

@ -2,7 +2,7 @@ name: "CodeQL"
on:
push:
branches: [main, next, version*]
branches: [main, "*", next, version*]
pull_request:
branches: [main]
schedule:

1
.gitignore vendored
View File

@ -217,4 +217,3 @@ source_docs/
### Docker ###
docker-compose.override.yml
tests/openid_conformance/exports/*.zip

View File

@ -75,7 +75,7 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 4: Download uv
FROM ghcr.io/astral-sh/uv:0.7.14 AS uv
FROM ghcr.io/astral-sh/uv:0.7.13 AS uv
# Stage 5: Base python image
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base

View File

@ -86,10 +86,6 @@ dev-create-db:
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
update-test-mmdb: ## Update test GeoIP and ASN Databases
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb
#########################
## API Schema
#########################

View File

@ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable:
for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
if "local" in str(blueprint_file) or "testing" in str(blueprint_file):
if "local" in str(blueprint_file):
continue
setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))

View File

@ -1,6 +1,8 @@
"""Authenticator Devices API Views"""
from drf_spectacular.utils import extend_schema
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.fields import (
BooleanField,
@ -13,7 +15,6 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from authentik.core.api.users import ParamUserSerializer
from authentik.core.api.utils import MetaNameSerializer
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
from authentik.stages.authenticator import device_classes, devices_for_user
@ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
class DeviceSerializer(MetaNameSerializer):
"""Serializer for authenticator devices"""
"""Serializer for Duo authenticator devices"""
pk = CharField()
name = CharField()
@ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer):
last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True)
extra_description = SerializerMethodField()
external_id = SerializerMethodField()
def get_type(self, instance: Device) -> str:
"""Get type of device"""
return instance._meta.label
def get_extra_description(self, instance: Device) -> str | None:
def get_extra_description(self, instance: Device) -> str:
"""Get extra description"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.description if instance.device_type else None
return (
instance.device_type.description
if instance.device_type
else _("Extra description not available")
)
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
def get_external_id(self, instance: Device) -> str | None:
"""Get external Device ID"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.aaguid if instance.device_type else None
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
return ""
class DeviceViewSet(ViewSet):
@ -61,6 +57,7 @@ class DeviceViewSet(ViewSet):
serializer_class = DeviceSerializer
permission_classes = [IsAuthenticated]
@extend_schema(responses={200: DeviceSerializer(many=True)})
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
devices = devices_for_user(request.user)
@ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet):
yield from device_set
@extend_schema(
parameters=[ParamUserSerializer],
parameters=[
OpenApiParameter(
name="user",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
)
],
responses={200: DeviceSerializer(many=True)},
)
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
args = ParamUserSerializer(data=request.query_params)
args.is_valid(raise_exception=True)
return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data)
kwargs = {}
if "user" in request.query_params:
kwargs = {"user": request.query_params["user"]}
return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data)

View File

@ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger()
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False)
class UserGroupSerializer(ModelSerializer):
"""Simplified Group Serializer for user's groups"""
@ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
queryset = User.objects.none()
ordering = ["username"]
serializer_class = UserSerializer
filterset_class = UsersFilter
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
def get_ql_fields(self):
from djangoql.schema import BoolField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
StrField(User, "username"),
StrField(User, "name"),
StrField(User, "email"),
StrField(User, "path"),
BoolField(User, "is_active", nullable=True),
ChoiceSearchField(User, "type"),
JSONSearchField(User, "attributes"),
]
filterset_class = UsersFilter
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()

View File

@ -13,6 +13,7 @@ class Command(TenantCommand):
parser.add_argument("usernames", nargs="*", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()

View File

@ -18,7 +18,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_cte import CTE, with_cte
from django_cte import CTEQuerySet, With
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
@ -136,7 +136,7 @@ class AttributesMixin(models.Model):
return instance, False
class GroupQuerySet(QuerySet):
class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""
@ -165,9 +165,9 @@ class GroupQuerySet(QuerySet):
)
# Build the recursive query, see above
cte = CTE.recursive(make_cte)
cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group.
return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid))
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin):

View File

@ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase):
[
UserPasswordHistory(
user=self.user,
old_password="hunter1", # nosec
old_password="hunter1", # nosec B106
created_at=_now - timedelta(days=3),
),
UserPasswordHistory(
user=self.user,
old_password="hunter2", # nosec
old_password="hunter2", # nosec B106
created_at=_now - timedelta(days=2),
),
UserPasswordHistory(
user=self.user,
old_password="hunter3", # nosec
old_password="hunter3", # nosec B106
created_at=_now,
),
]

View File

@ -1,8 +1,10 @@
from hashlib import sha256
from django.contrib.auth.signals import user_logged_out
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http.request import HttpRequest
from guardian.shortcuts import assign_perm
from authentik.core.models import (
@ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
instance.save()
@receiver(user_logged_out)
def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_):
"""Session revoked trigger (user logged out)"""
if not request.session or not request.session.session_key or not user:
return
send_ssf_event(
EventTypes.CAEP_SESSION_REVOKED,
{
"initiating_entity": "user",
},
sub_id={
"format": "complex",
"session": {
"format": "opaque",
"id": sha256(request.session.session_key.encode("ascii")).hexdigest(),
},
"user": {
"format": "email",
"email": user.email,
},
},
request=request,
)
@receiver(pre_delete, sender=AuthenticatedSession)
def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
"""Session revoked trigger (users' session has been deleted)

View File

@ -1,12 +0,0 @@
"""Enterprise app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseSearchConfig(EnterpriseConfig):
"""Enterprise app config"""
name = "authentik.enterprise.search"
label = "authentik_search"
verbose_name = "authentik Enterprise.Search"
default = True

View File

@ -1,128 +0,0 @@
"""DjangoQL search"""
from collections import OrderedDict, defaultdict
from collections.abc import Generator
from django.db import connection
from django.db.models import Model, Q
from djangoql.compat import text_type
from djangoql.schema import StrField
class JSONSearchField(StrField):
"""JSON field for DjangoQL"""
model: Model
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
# Set this in the constructor to not clobber the type variable
self.type = "relation"
self.suggest_nested = suggest_nested
super().__init__(model, name, nullable)
def get_lookup(self, path, operator, value):
search = "__".join(path)
op, invert = self.get_operator(operator)
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
return ~q if invert else q
def json_field_keys(self) -> Generator[tuple[str]]:
with connection.cursor() as cursor:
cursor.execute(
f"""
WITH RECURSIVE "{self.name}_keys" AS (
SELECT
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
FROM {self.model._meta.db_table}
WHERE "{self.name}" IS NOT NULL
AND jsonb_typeof("{self.name}") = 'object'
UNION ALL
SELECT
ck.key_path_array || jsonb_object_keys(ck.value),
ck.value -> jsonb_object_keys(ck.value) AS value
FROM "{self.name}_keys" ck
WHERE jsonb_typeof(ck.value) = 'object'
),
unique_paths AS (
SELECT DISTINCT key_path_array
FROM "{self.name}_keys"
)
SELECT key_path_array FROM unique_paths;
""" # nosec
)
return (x[0] for x in cursor.fetchall())
def get_nested_options(self) -> OrderedDict:
"""Get keys of all nested objects to show autocomplete"""
if not self.suggest_nested:
return OrderedDict()
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
if not parent_parts:
parent_parts = []
path = parts.pop(0)
parent_parts.append(path)
relation_key = "_".join(parent_parts)
if len(parts) > 1:
out_dict = {
relation_key: {
parts[0]: {
"type": "relation",
"relation": f"{relation_key}_{parts[0]}",
}
}
}
child_paths = recursive_function(parts.copy(), parent_parts.copy())
child_paths.update(out_dict)
return child_paths
else:
return {relation_key: {parts[0]: {}}}
relation_structure = defaultdict(dict)
for relations in self.json_field_keys():
result = recursive_function([base_model_name] + relations)
for relation_key, value in result.items():
for sub_relation_key, sub_value in value.items():
if not relation_structure[relation_key].get(sub_relation_key, None):
relation_structure[relation_key][sub_relation_key] = sub_value
else:
relation_structure[relation_key][sub_relation_key].update(sub_value)
final_dict = defaultdict(dict)
for key, value in relation_structure.items():
for sub_key, sub_value in value.items():
if not sub_value:
final_dict[key][sub_key] = {
"type": "str",
"nullable": True,
}
else:
final_dict[key][sub_key] = sub_value
return OrderedDict(final_dict)
def relation(self) -> str:
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
class ChoiceSearchField(StrField):
def __init__(self, model=None, name=None, nullable=None):
super().__init__(model, name, nullable, suggest_options=True)
def get_options(self, search):
result = []
choices = self._field_choices()
if choices:
search = search.lower()
for c in choices:
choice = text_type(c[0])
if search in choice.lower():
result.append(choice)
return result

View File

@ -1,53 +0,0 @@
from rest_framework.response import Response
from authentik.api.pagination import Pagination
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch
class AutocompletePagination(Pagination):
def paginate_queryset(self, queryset, request, view=None):
self.view = view
return super().paginate_queryset(queryset, request, view)
def get_autocomplete(self):
schema = QLSearch().get_schema(self.request, self.view)
introspections = {}
if hasattr(self.view, "get_ql_fields"):
from authentik.enterprise.search.schema import AKQLSchemaSerializer
introspections = AKQLSchemaSerializer().serialize(
schema(self.page.paginator.object_list.model)
)
return introspections
def get_paginated_response(self, data):
previous_page_number = 0
if self.page.has_previous():
previous_page_number = self.page.previous_page_number()
next_page_number = 0
if self.page.has_next():
next_page_number = self.page.next_page_number()
return Response(
{
"pagination": {
"next": next_page_number,
"previous": previous_page_number,
"count": self.page.paginator.count,
"current": self.page.number,
"total_pages": self.page.paginator.num_pages,
"start_index": self.page.start_index(),
"end_index": self.page.end_index(),
},
"results": data,
"autocomplete": self.get_autocomplete(),
}
)
def get_paginated_response_schema(self, schema):
final_schema = super().get_paginated_response_schema(schema)
final_schema["properties"]["autocomplete"] = {
"$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}"
}
final_schema["required"].append("autocomplete")
return final_schema

View File

@ -1,78 +0,0 @@
"""DjangoQL search"""
from django.apps import apps
from django.db.models import QuerySet
from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
from authentik.enterprise.search.fields import JSONSearchField
LOGGER = get_logger()
AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete"
AUTOCOMPLETE_SCHEMA = {
"type": "object",
"additionalProperties": {},
}
class BaseSchema(DjangoQLSchema):
"""Base Schema which deals with JSON Fields"""
def resolve_name(self, name: Name):
model = self.model_label(self.current_model)
root_field = name.parts[0]
field = self.models[model].get(root_field)
# If the query goes into a JSON field, return the root
# field as the JSON field will do the rest
if isinstance(field, JSONSearchField):
# This is a workaround; build_filter will remove the right-most
# entry in the path as that is intended to be the same as the field
# however for JSON that is not the case
if name.parts[-1] != root_field:
name.parts.append(root_field)
return field
return super().resolve_name(name)
class QLSearch(SearchFilter):
"""rest_framework search filter which uses DjangoQL"""
@property
def enabled(self):
return apps.get_app_config("authentik_enterprise").enabled()
def get_search_terms(self, request) -> str:
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
params = request.query_params.get(self.search_param, "")
params = params.replace("\x00", "") # strip null characters
return params
def get_schema(self, request: Request, view) -> BaseSchema:
ql_fields = []
if hasattr(view, "get_ql_fields"):
ql_fields = view.get_ql_fields()
class InlineSchema(BaseSchema):
def get_fields(self, model):
return ql_fields or []
return InlineSchema
def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
search_query = self.get_search_terms(request)
schema = self.get_schema(request, view)
if len(search_query) == 0 or not self.enabled:
return super().filter_queryset(request, queryset, view)
try:
return apply_search(queryset, search_query, schema=schema)
except DjangoQLError as exc:
LOGGER.debug("Failed to parse search expression", exc=exc)
return super().filter_queryset(request, queryset, view)

View File

@ -1,29 +0,0 @@
from djangoql.serializers import DjangoQLSchemaSerializer
from drf_spectacular.generators import SchemaGenerator
from authentik.api.schema import create_component
from authentik.enterprise.search.fields import JSONSearchField
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
def serialize(self, schema):
serialization = super().serialize(schema)
for _, fields in schema.models.items():
for _, field in fields.items():
if not isinstance(field, JSONSearchField):
continue
serialization["models"].update(field.get_nested_options())
return serialization
def serialize_field(self, field):
result = super().serialize_field(field)
if isinstance(field, JSONSearchField):
result["relation"] = field.relation()
return result
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA)
return result

View File

@ -1,17 +0,0 @@
SPECTACULAR_SETTINGS = {
"POSTPROCESSING_HOOKS": [
"authentik.api.schema.postprocess_schema_responses",
"authentik.enterprise.search.schema.postprocess_schema_search_autocomplete",
"drf_spectacular.hooks.postprocess_schema_enums",
],
}
REST_FRAMEWORK = {
"DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination",
"DEFAULT_FILTER_BACKENDS": [
"authentik.enterprise.search.ql.QLSearch",
"authentik.rbac.filters.ObjectFilter",
"django_filters.rest_framework.DjangoFilterBackend",
"rest_framework.filters.OrderingFilter",
],
}

View File

@ -1,78 +0,0 @@
from json import loads
from unittest.mock import PropertyMock, patch
from urllib.parse import urlencode
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
class QLTest(APITestCase):
def setUp(self):
self.user = create_test_admin_user()
# ensure we have more than 1 user
create_test_admin_user()
def test_search(self):
"""Test simple search query"""
self.client.force_login(self.user)
query = f'username = "{self.user.username}"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_no_search(self):
"""Ensure works with no search query"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertNotEqual(content["pagination"]["count"], 1)
def test_search_no_ql(self):
"""Test simple search query (no QL)"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": self.user.username})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertGreaterEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_search_json(self):
"""Test search query with a JSON attribute"""
self.user.attributes = {"foo": {"bar": "baz"}}
self.user.save()
self.client.force_login(self.user)
query = 'attributes.foo.bar = "baz"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)

View File

@ -18,7 +18,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.ssf",
"authentik.enterprise.search",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source",

View File

@ -97,7 +97,6 @@ class SourceStageFinal(StageView):
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan
plan.context.update(self.executor.plan.context)
plan.context[PLAN_CONTEXT_IS_RESTORED] = token
response = plan.to_redirect(self.request, token.flow)
token.delete()

View File

@ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase):
plan: FlowPlan = session[SESSION_KEY_PLAN]
plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
plan.context["foo"] = "bar"
session[SESSION_KEY_PLAN] = plan
session.save()
# Pretend we've just returned from the source
with self.assertFlowFinishes() as ff:
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)
self.assertEqual(ff().context["foo"], "bar")
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)

View File

@ -132,22 +132,6 @@ class EventViewSet(ModelViewSet):
]
filterset_class = EventsFilter
def get_ql_fields(self):
from djangoql.schema import DateTimeField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
ChoiceSearchField(Event, "action"),
StrField(Event, "event_uuid"),
StrField(Event, "app", suggest_options=True),
StrField(Event, "client_ip"),
JSONSearchField(Event, "user", suggest_nested=False),
JSONSearchField(Event, "brand", suggest_nested=False),
JSONSearchField(Event, "context", suggest_nested=False),
DateTimeField(Event, "created", suggest_options=True),
]
@extend_schema(
methods=["GET"],
responses={200: EventTopPerUserSerializer(many=True)},

View File

@ -11,7 +11,7 @@ from authentik.events.models import NotificationRule
class NotificationRuleSerializer(ModelSerializer):
"""NotificationRule Serializer"""
destination_group_obj = GroupSerializer(read_only=True, source="destination_group")
group_obj = GroupSerializer(read_only=True, source="group")
class Meta:
model = NotificationRule
@ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer):
"name",
"transports",
"severity",
"destination_group",
"destination_group_obj",
"destination_event_user",
"group",
"group_obj",
]
@ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet):
queryset = NotificationRule.objects.all()
serializer_class = NotificationRuleSerializer
filterset_fields = ["name", "severity", "destination_group__name"]
filterset_fields = ["name", "severity", "group__name"]
ordering = ["name"]
search_fields = ["name", "destination_group__name"]
search_fields = ["name", "group__name"]

View File

@ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor):
self.reader: Reader | None = None
self._last_mtime: float = 0.0
self.logger = get_logger()
self.load()
self.open()
def path(self) -> str | None:
"""Get the path to the MMDB file to load"""
raise NotImplementedError
def load(self):
def open(self):
"""Get GeoIP Reader, if configured, otherwise none"""
path = self.path()
if path == "" or not path:
@ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor):
diff = self._last_mtime < mtime
if diff > 0:
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
self.load()
self.open()
except OSError as exc:
self.logger.warning("Failed to check MMDB age", exc=exc)

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models
from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
from authentik.stages.authenticator_static.models import StaticToken
@ -173,7 +173,7 @@ class AuditMiddleware:
message=exception_to_string(exception),
)
thread.run()
elif not should_ignore_exception(exception):
elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
thread = EventNewThread(
EventAction.SYSTEM_EXCEPTION,
request,

View File

@ -1,26 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-16 23:21
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
]
operations = [
migrations.RenameField(
model_name="notificationrule",
old_name="group",
new_name="destination_group",
),
migrations.AddField(
model_name="notificationrule",
name="destination_event_user",
field=models.BooleanField(
default=False,
help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.",
),
),
]

View File

@ -1,12 +1,10 @@
"""authentik events models"""
from collections.abc import Generator
from datetime import timedelta
from difflib import get_close_matches
from functools import lru_cache
from inspect import currentframe
from smtplib import SMTPException
from typing import Any
from uuid import uuid4
from django.apps import apps
@ -549,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
default=NotificationSeverity.NOTICE,
help_text=_("Controls which severity level the created notifications will have."),
)
destination_group = models.ForeignKey(
group = models.ForeignKey(
Group,
help_text=_(
"Define which group of users this notification should be sent and shown to. "
@ -559,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
blank=True,
on_delete=models.SET_NULL,
)
destination_event_user = models.BooleanField(
default=False,
help_text=_(
"When enabled, notification will be sent to user the user that triggered the event."
"When destination_group is configured, notification is sent to both."
),
)
def destination_users(self, event: Event) -> Generator[User, Any]:
if self.destination_event_user and event.user.get("pk"):
yield User(pk=event.user.get("pk"))
if self.destination_group:
yield from self.destination_group.users.all()
@property
def serializer(self) -> type[Serializer]:

View File

@ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if not result.passing:
return
if not trigger.group:
LOGGER.debug("e(trigger): trigger has no group", trigger=trigger)
return
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
# Create the notification objects
for transport in trigger.transports.all():
for user in trigger.destination_users(event):
for user in trigger.group.users.all():
LOGGER.debug("created notification")
notification_transport.apply_async(
args=[

View File

@ -2,9 +2,7 @@
from django.test import TestCase
from authentik.events.context_processors.base import get_context_processors
from authentik.events.context_processors.geoip import GeoIPContextProcessor
from authentik.events.models import Event, EventAction
class TestGeoIP(TestCase):
@ -15,7 +13,8 @@ class TestGeoIP(TestCase):
def test_simple(self):
"""Test simple city wrapper"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
# IPs from
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
self.assertEqual(
self.reader.city_dict("2.125.160.216"),
{
@ -26,12 +25,3 @@ class TestGeoIP(TestCase):
"long": -1.25,
},
)
def test_special_chars(self):
"""Test city name with special characters"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
event = Event.new(EventAction.LOGIN)
event.client_ip = "89.160.20.112"
for processor in get_context_processors():
processor.enrich_event(event)
event.save()

View File

@ -6,7 +6,6 @@ from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import (
Event,
EventAction,
@ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_empty(self):
"""Test trigger without any policies attached"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
@ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_single(self):
"""Test simple transport triggering"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase):
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(execute_mock.call_count, 1)
def test_trigger_event_user(self):
"""Test trigger with event user"""
user = create_test_user()
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX
)
PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
execute_mock = MagicMock()
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save()
self.assertEqual(execute_mock.call_count, 1)
notification: Notification = execute_mock.call_args[0][0]
self.assertEqual(notification.user, user)
def test_trigger_no_group(self):
"""Test trigger without group"""
trigger = NotificationRule.objects.create(name=generate_id())
@ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase):
"""Test Policy error which would cause recursion"""
transport = NotificationTransport.objects.create(name=generate_id())
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase):
transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase):
name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX

View File

@ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse
from django.test import override_settings
from django.test.client import RequestFactory
from django.urls import reverse
from rest_framework.exceptions import ParseError
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_flow, create_test_user
@ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase):
self.assertStageResponse(response, flow, component="ak-stage-identification")
response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-access-denied")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_json(self):
"""Test invalid JSON body"""
flow = create_test_flow()
FlowStageBinding.objects.create(
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
)
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
with override_settings(TEST=False, DEBUG=False):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)
with self.assertRaises(ParseError):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)

View File

@ -55,7 +55,7 @@ from authentik.flows.planner import (
FlowPlanner,
)
from authentik.flows.stage import AccessDeniedStage, StageView
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
@ -234,13 +234,12 @@ class FlowExecutorView(APIView):
"""Handle exception in stage execution"""
if settings.DEBUG or settings.TEST:
raise exc
capture_exception(exc)
self._logger.warning(exc)
if not should_ignore_exception(exc):
capture_exception(exc)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
challenge = FlowErrorChallenge(self.request, exc)
challenge.is_valid(raise_exception=True)
return to_stage_response(self.request, HttpChallengeResponse(challenge))

View File

@ -104,7 +104,6 @@ def get_logger_config():
"hpack": "WARNING",
"httpx": "WARNING",
"azure": "WARNING",
"httpcore": "WARNING",
}
for handler_name, level in handler_level_map.items():
base_config["loggers"][handler_name] = {

View File

@ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted
from docker.errors import DockerException
from h11 import LocalProtocolError
from ldap3.core.exceptions import LDAPException
from psycopg.errors import Error
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException
@ -45,49 +44,6 @@ class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry."""
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
class SentryTransport(HttpTransport):
"""Custom sentry transport with custom user-agent"""
@ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float:
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def should_ignore_exception(exc: Exception) -> bool:
"""Check if an exception should be dropped"""
return isinstance(exc, ignored_classes)
def before_send(event: dict, hint: dict) -> dict | None:
"""Check if error is database error, and ignore if so"""
from psycopg.errors import Error
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
exc_value = None
if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"]
if should_ignore_exception(exc_value):
if isinstance(exc_value, ignored_classes):
LOGGER.debug("dropping exception", exc=exc_value)
return None
if "logger" in event:

View File

@ -2,7 +2,7 @@
from django.test import TestCase
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException, before_send
class TestSentry(TestCase):
@ -10,8 +10,8 @@ class TestSentry(TestCase):
def test_error_not_sent(self):
"""Test SentryIgnoredError not sent"""
self.assertTrue(should_ignore_exception(SentryIgnoredException()))
self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
def test_error_sent(self):
"""Test error sent"""
self.assertFalse(should_ignore_exception(ValueError()))
self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)}))

View File

@ -1,13 +1,15 @@
"""authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import AuthenticatedSession, Provider
from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
@ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""

View File

@ -1,10 +1,23 @@
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
@receiver(user_logged_out)
def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
"""Revoke tokens upon user logout"""
if not request.session or not request.session.session_key:
return
AccessToken.objects.filter(
user=user,
session__session__session_key=request.session.session_key,
).delete()
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
"""Revoke tokens upon user logout"""

View File

@ -2,11 +2,13 @@
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION,
@ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import (
from authentik.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
if not request.session or not request.session.session_key:
return
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer()

View File

@ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -190,6 +190,7 @@ class SAMLProviderSerializer(ProviderSerializer):
"sign_response",
"sp_binding",
"default_relay_state",
"default_name_id_policy",
"url_download_metadata",
"url_sso_post",
"url_sso_redirect",

View File

@ -0,0 +1,31 @@
# Generated by Django 5.1.11 on 2025-06-18 09:27
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_saml", "0018_alter_samlprovider_acs_url"),
]
operations = [
migrations.AddField(
model_name="samlprovider",
name="default_name_id_policy",
field=models.TextField(
choices=[
("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", "Email"),
("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent", "Persistent"),
("urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName", "X509"),
(
"urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName",
"Windows",
),
("urn:oasis:names:tc:SAML:2.0:nameid-format:transient", "Transient"),
("urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified", "Unspecified"),
],
default="urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified",
),
),
]

View File

@ -12,6 +12,7 @@ from authentik.core.models import PropertyMapping, Provider
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.models import DomainlessURLValidator
from authentik.lib.utils.time import timedelta_string_validator
from authentik.sources.saml.models import SAMLNameIDPolicy
from authentik.sources.saml.processors.constants import (
DSA_SHA1,
ECDSA_SHA1,
@ -179,6 +180,9 @@ class SAMLProvider(Provider):
default_relay_state = models.TextField(
default="", blank=True, help_text=_("Default relay_state value for IDP-initiated logins")
)
default_name_id_policy = models.TextField(
choices=SAMLNameIDPolicy.choices, default=SAMLNameIDPolicy.UNSPECIFIED
)
sign_assertion = models.BooleanField(default=True)
sign_response = models.BooleanField(default=False)

View File

@ -205,6 +205,13 @@ class AssertionProcessor:
def get_name_id(self) -> Element:
"""Get NameID Element"""
name_id = Element(f"{{{NS_SAML_ASSERTION}}}NameID")
# For requests that don't specify a NameIDPolicy, check if we
# can fall back to the provider default
if (
self.auth_n_request.name_id_policy == SAML_NAME_ID_FORMAT_UNSPECIFIED
and self.provider.default_name_id_policy != SAML_NAME_ID_FORMAT_UNSPECIFIED
):
self.auth_n_request.name_id_policy = self.provider.default_name_id_policy
name_id.attrib["Format"] = self.auth_n_request.name_id_policy
# persistent is used as a fallback, so always generate it
persistent = self.http_request.user.uid

View File

@ -13,6 +13,7 @@ from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.exceptions import CannotHandleAssertion
from authentik.providers.saml.models import SAMLProvider
from authentik.providers.saml.utils.encoding import decode_base64_and_inflate
from authentik.sources.saml.models import SAMLNameIDPolicy
from authentik.sources.saml.processors.constants import (
DSA_SHA1,
NS_MAP,
@ -175,7 +176,9 @@ class AuthNRequestParser:
def idp_initiated(self) -> AuthNRequest:
"""Create IdP Initiated AuthNRequest"""
relay_state = None
request = AuthNRequest(relay_state=None)
if self.provider.default_relay_state != "":
relay_state = self.provider.default_relay_state
return AuthNRequest(relay_state=relay_state)
request.relay_state = self.provider.default_relay_state
if self.provider.default_name_id_policy != SAMLNameIDPolicy.UNSPECIFIED:
request.name_id_policy = self.provider.default_name_id_policy
return request

View File

@ -13,6 +13,7 @@ from authentik.crypto.models import CertificateKeyPair
from authentik.flows.models import Flow
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.utils.encoding import PEM_FOOTER, PEM_HEADER
from authentik.sources.saml.models import SAMLNameIDPolicy
from authentik.sources.saml.processors.constants import (
NS_MAP,
NS_SAML_METADATA,
@ -46,6 +47,7 @@ class ServiceProviderMetadata:
auth_n_request_signed: bool
assertion_signed: bool
name_id_policy: SAMLNameIDPolicy
signing_keypair: CertificateKeyPair | None = None
@ -60,6 +62,7 @@ class ServiceProviderMetadata:
provider.issuer = self.entity_id
provider.sp_binding = self.acs_binding
provider.acs_url = self.acs_location
provider.default_name_id_policy = self.name_id_policy
if self.signing_keypair and self.auth_n_request_signed:
self.signing_keypair.name = f"Provider {name} - SAML Signing Certificate"
self.signing_keypair.save()
@ -148,6 +151,11 @@ class ServiceProviderMetadataParser:
if signing_keypair:
self.check_signature(root, signing_keypair)
name_id_format = descriptor.findall(f"{{{NS_SAML_METADATA}}}NameIDFormat")
name_id_policy = SAMLNameIDPolicy.UNSPECIFIED
if len(name_id_format) > 0:
name_id_policy = SAMLNameIDPolicy(name_id_format[0].text)
return ServiceProviderMetadata(
entity_id=entity_id,
acs_binding=acs_binding,
@ -155,4 +163,5 @@ class ServiceProviderMetadataParser:
auth_n_request_signed=auth_n_request_signed,
assertion_signed=assertion_signed,
signing_keypair=signing_keypair,
name_id_policy=name_id_policy,
)

View File

@ -4,7 +4,7 @@
cacheDuration="PT604800S"
entityID="http://localhost:8080/saml/metadata">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="false" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="http://localhost:8080/saml/acs"
index="1" />

View File

@ -14,6 +14,7 @@ from authentik.lib.xml import lxml_from_string
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
from authentik.providers.saml.processors.metadata import MetadataProcessor
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
from authentik.sources.saml.models import SAMLNameIDPolicy
from authentik.sources.saml.processors.constants import ECDSA_SHA256, NS_MAP, NS_SAML_METADATA
@ -86,6 +87,7 @@ class TestServiceProviderMetadataParser(TestCase):
self.assertEqual(provider.acs_url, "http://localhost:8080/saml/acs")
self.assertEqual(provider.issuer, "http://localhost:8080/saml/metadata")
self.assertEqual(provider.sp_binding, SAMLBindings.POST)
self.assertEqual(provider.default_name_id_policy, SAMLNameIDPolicy.EMAIL)
self.assertEqual(
len(provider.property_mappings.all()),
len(SAMLPropertyMapping.objects.exclude(managed__isnull=True)),

View File

@ -5,6 +5,7 @@ from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import (
SCIM_GROUP_SCHEMA,
PatchOp,
PatchOperation,
PatchRequest,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,

View File

@ -1,7 +1,5 @@
"""Custom SCIM schemas"""
from enum import Enum
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
@ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
)
class PatchOp(str, Enum):
replace = "replace"
remove = "remove"
add = "add"
@classmethod
def _missing_(cls, value):
value = value.lower()
for member in cls:
if member.lower() == value:
return member
return None
class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas"""
@ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest):
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
op: PatchOp
path: str | None

View File

@ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -38,7 +38,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -74,7 +73,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
import django
from channels.routing import ProtocolTypeRouter, URLRouter
from defusedxml import defuse_stdlib
from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from authentik.root.setup import setup
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
setup()
defuse_stdlib()
django.setup()

View File

@ -27,7 +27,7 @@ from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik import get_full_version
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
# set the default Django settings module for the 'celery' program.
@ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
CTX_TASK_ID.set(...)
if not should_ignore_exception(exception):
if before_send({}, {"exc_info": (None, exception, None)}) is not None:
Event.new(
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
).save()

View File

@ -1,49 +1,13 @@
"""authentik database backend"""
from django.core.checks import Warning
from django.db.backends.base.validation import BaseDatabaseValidation
from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
return self._check_encoding()
def _check_encoding(self):
"""Throw a warning when the server_encoding is not UTF-8 or
server_encoding and client_encoding are mismatched"""
messages = []
with self.connection.cursor() as cursor:
cursor.execute("SHOW server_encoding;")
server_encoding = cursor.fetchone()[0]
cursor.execute("SHOW client_encoding;")
client_encoding = cursor.fetchone()[0]
if server_encoding != client_encoding:
messages.append(
Warning(
"PostgreSQL Server and Client encoding are mismatched: Server: "
f"{server_encoding}, Client: {client_encoding}",
id="ak.db.W001",
)
)
if server_encoding != "UTF8":
messages.append(
Warning(
f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
id="ak.db.W002",
)
)
return messages
class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials"""
validation_class = DatabaseValidation
def get_connection_params(self):
"""Refresh DB credentials before getting connection params"""
conn_params = super().get_connection_params()

View File

@ -166,6 +166,7 @@ SPECTACULAR_SETTINGS = {
"UserVerificationEnum": "authentik.stages.authenticator_webauthn.models.UserVerification",
"UserTypeEnum": "authentik.core.models.UserTypes",
"OutgoingSyncDeleteAction": "authentik.lib.sync.outgoing.models.OutgoingSyncDeleteAction",
"SAMLNameIDPolicyEnum": "authentik.sources.saml.models.SAMLNameIDPolicy",
},
"ENUM_ADD_EXPLICIT_BLANK_NULL_CHOICE": False,
"ENUM_GENERATE_CHOICE_DESCRIPTION": False,
@ -446,8 +447,6 @@ _DISALLOWED_ITEMS = [
"MIDDLEWARE",
"AUTHENTICATION_BACKENDS",
"CELERY",
"SPECTACULAR_SETTINGS",
"REST_FRAMEWORK",
]
SILENCED_SYSTEM_CHECKS = [
@ -470,8 +469,6 @@ def _update_settings(app_path: str):
TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:

View File

@ -1,26 +0,0 @@
import os
import warnings
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from authentik.lib.config import CONFIG
def setup():
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
"ignore",
"defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
"ignore",
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")

View File

@ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType
from django.test.runner import DiscoverRunner
from structlog.stdlib import get_logger
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.config import CONFIG
from authentik.lib.sentry import sentry_init
from authentik.root.signals import post_startup, pre_startup, startup
@ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
for key, value in test_config.items():
CONFIG.set(key, value)
ASN_CONTEXT_PROCESSOR.load()
GEOIP_CONTEXT_PROCESSOR.load()
sentry_init()
self.logger.debug("Test environment configured")

View File

@ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str):
return
# Delete all sync tasks from the cache
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
# The order of these operations needs to be preserved as each depends on the previous one(s)
# 1. User and group sync can happen simultaneously
# 2. Membership sync needs to run afterwards
# 3. Finally, user and group deletions can happen simultaneously
user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator(
source, GroupLDAPSynchronizer
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
# Finally, deletions. What we'd really like to do here is something like
# ```
# user_identifiers = <ldap query>
# User.objects.exclude(
# usersourceconnection__identifier__in=user_uniqueness_identifiers,
# ).delete()
# ```
# This runs into performance issues in large installations. So instead we spread the
# work out into three steps:
# 1. Get every object from the LDAP source.
# 2. Mark every object as "safe" in the database. This is quick, but any error could
# mean deleting users which should not be deleted, so we do it immediately, in
# large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks.
group(
ldap_sync_paginator(source, UserLDAPForwardDeletion)
+ ldap_sync_paginator(source, GroupLDAPForwardDeletion),
),
)
membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer)
user_group_deletion = ldap_sync_paginator(
source, UserLDAPForwardDeletion
) + ldap_sync_paginator(source, GroupLDAPForwardDeletion)
# Celery is buggy with empty groups, so we are careful only to add non-empty groups.
# See https://github.com/celery/celery/issues/9772
task_groups = []
if user_group_sync:
task_groups.append(group(user_group_sync))
if membership_sync:
task_groups.append(group(membership_sync))
if user_group_deletion:
task_groups.append(group(user_group_deletion))
all_tasks = chain(task_groups)
all_tasks()
task()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:

View File

@ -0,0 +1,32 @@
# Generated by Django 5.1.11 on 2025-06-18 09:27
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_sources_saml", "0019_migrate_usersamlsourceconnection_identifier"),
]
operations = [
migrations.AlterField(
model_name="samlsource",
name="name_id_policy",
field=models.TextField(
choices=[
("urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress", "Email"),
("urn:oasis:names:tc:SAML:2.0:nameid-format:persistent", "Persistent"),
("urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName", "X509"),
(
"urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName",
"Windows",
),
("urn:oasis:names:tc:SAML:2.0:nameid-format:transient", "Transient"),
("urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified", "Unspecified"),
],
default="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
help_text="NameID Policy sent to the IdP. Can be unset, in which case no Policy is sent.",
),
),
]

View File

@ -39,6 +39,7 @@ from authentik.sources.saml.processors.constants import (
SAML_NAME_ID_FORMAT_EMAIL,
SAML_NAME_ID_FORMAT_PERSISTENT,
SAML_NAME_ID_FORMAT_TRANSIENT,
SAML_NAME_ID_FORMAT_UNSPECIFIED,
SAML_NAME_ID_FORMAT_WINDOWS,
SAML_NAME_ID_FORMAT_X509,
SHA1,
@ -73,6 +74,7 @@ class SAMLNameIDPolicy(models.TextChoices):
X509 = SAML_NAME_ID_FORMAT_X509
WINDOWS = SAML_NAME_ID_FORMAT_WINDOWS
TRANSIENT = SAML_NAME_ID_FORMAT_TRANSIENT
UNSPECIFIED = SAML_NAME_ID_FORMAT_UNSPECIFIED
class SAMLSource(Source):

View File

@ -1,277 +0,0 @@
"""Test SCIM Group"""
from json import dumps
from uuid import uuid4
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.sources.scim.models import (
SCIMSource,
SCIMSourceGroup,
)
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
class TestSCIMGroups(APITestCase):
"""Test SCIM Group view"""
def setUp(self) -> None:
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
def test_group_list(self):
"""Test full group list"""
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_group_list_single(self):
"""Test full group list (single group)"""
group = Group.objects.create(name=generate_id())
user = create_test_user()
group.users.add(user)
SCIMSourceGroup.objects.create(
source=self.source,
group=group,
id=str(uuid4()),
)
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(group.pk),
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
SCIMGroupSchema.model_validate_json(response.content, strict=True)
def test_group_create(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members(self):
"""Test group create"""
user = create_test_user()
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{
"displayName": generate_id(),
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members_empty(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_duplicate(self):
"""Test group create (duplicate)"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 409)
self.assertJSONEqual(
response.content,
{
"detail": "Group with ID exists already.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"scimType": "uniqueness",
"status": 409,
},
)
def test_group_update(self):
"""Test group update"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
def test_group_update_non_existent(self):
"""Test group update"""
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(uuid4()),
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=404)
self.assertJSONEqual(
response.content,
{
"detail": "Group not found.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"status": 404,
},
)
def test_group_patch_add(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "Add",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertTrue(group.users.filter(pk=user.pk).exists())
def test_group_patch_remove(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(user)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "remove",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertFalse(group.users.filter(pk=user.pk).exists())
def test_group_delete(self):
"""Test group delete"""
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=204)

View File

@ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase):
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
"0123456789",
)
def test_user_update(self):
"""Test user update"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
data=dumps(
{
"id": str(existing.pk),
"userName": generate_id(),
"externalId": ext_id,
"emails": [
{
"primary": True,
"value": user.email,
}
],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_user_delete(self):
"""Test user delete"""
user = create_test_user()
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 204)

View File

@ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_
from rest_framework.request import Request
from rest_framework.views import APIView
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import Token, TokenIntents, User
from authentik.sources.scim.models import SCIMSource
@ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication):
_username, _, password = b64decode(key.encode()).decode().partition(":")
token = self.check_token(password, source_slug)
if token:
CTX_AUTH_VIA.set("scim_basic")
return (token.user, token)
return None
@ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication):
token = self.check_token(key, source_slug)
if not token:
return None
CTX_AUTH_VIA.set("scim_token")
return (token.user, token)

View File

@ -1,11 +1,13 @@
"""SCIM Utils"""
from typing import Any
from urllib.parse import urlparse
from django.conf import settings
from django.core.paginator import Page, Paginator
from django.db.models import Q, QuerySet
from django.http import HttpRequest
from django.urls import resolve
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer
@ -44,7 +46,7 @@ class SCIMView(APIView):
logger: BoundLogger
permission_classes = [IsAuthenticated]
parser_classes = [SCIMParser, JSONParser]
parser_classes = [SCIMParser]
renderer_classes = [SCIMRenderer]
def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
@ -54,6 +56,28 @@ class SCIMView(APIView):
def get_authenticators(self):
return [SCIMTokenAuth(self)]
def patch_resolve_value(self, raw_value: dict) -> User | Group | None:
"""Attempt to resolve a raw `value` attribute of a patch operation into
a database model"""
model = User
query = {}
if "$ref" in raw_value:
url = urlparse(raw_value["$ref"])
if match := resolve(url.path):
if match.url_name == "v2-users":
model = User
query = {"pk": int(match.kwargs["user_id"])}
elif "type" in raw_value:
match raw_value["type"]:
case "User":
model = User
query = {"pk": int(raw_value["value"])}
case "Group":
model = Group
else:
return None
return model.objects.filter(**query).first()
def filter_parse(self, request: Request):
"""Parse the path of a Patch Operation"""
path = request.query_params.get("filter")

View File

@ -1,58 +0,0 @@
from enum import Enum
from pydanticscim.responses import SCIMError as BaseSCIMError
from rest_framework.exceptions import ValidationError
class SCIMErrorTypes(Enum):
invalid_filter = "invalidFilter"
too_many = "tooMany"
uniqueness = "uniqueness"
mutability = "mutability"
invalid_syntax = "invalidSyntax"
invalid_path = "invalidPath"
no_target = "noTarget"
invalid_value = "invalidValue"
invalid_vers = "invalidVers"
sensitive = "sensitive"
class SCIMError(BaseSCIMError):
scimType: SCIMErrorTypes | None = None
detail: str | None = None
class SCIMValidationError(ValidationError):
status_code = 400
default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400)
def __init__(self, detail: SCIMError | None):
if detail is None:
detail = self.default_detail
detail.status = self.status_code
self.detail = detail.model_dump(mode="json", exclude_none=True)
class SCIMConflictError(SCIMValidationError):
status_code = 409
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
scimType=SCIMErrorTypes.uniqueness,
status=self.status_code,
)
)
class SCIMNotFoundError(SCIMValidationError):
status_code = 404
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
status=self.status_code,
)
)

View File

@ -4,25 +4,19 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydantic import ValidationError as PydanticValidationError
from pydanticscim.group import GroupMember
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from scim2_filter_parser.attr_paths import AttrPath
from authentik.core.models import Group, User
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
from authentik.sources.scim.models import SCIMSourceGroup
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import (
SCIMConflictError,
SCIMNotFoundError,
SCIMValidationError,
)
class GroupsView(SCIMObjectView):
@ -33,7 +27,7 @@ class GroupsView(SCIMObjectView):
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
"""Convert Group to SCIM data"""
payload = SCIMGroupModel(
schemas=[SCIM_GROUP_SCHEMA],
schemas=[SCIM_USER_SCHEMA],
id=str(scim_group.group.pk),
externalId=scim_group.id,
displayName=scim_group.group.name,
@ -64,7 +58,7 @@ class GroupsView(SCIMObjectView):
if group_id:
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
return Response(self.group_to_scim(connection))
connections = (
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
@ -125,7 +119,7 @@ class GroupsView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing group")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_group(None, request.data)
return Response(self.group_to_scim(connection), status=201)
@ -135,44 +129,10 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection = self.update_group(connection, request.data)
return Response(self.group_to_scim(connection), status=200)
@atomic
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
"""Patch group handler"""
connection = SCIMSourceGroup.objects.filter(
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
for _op in request.data.get("Operations", []):
operation = PatchOperation.model_validate(_op)
if operation.op.lower() not in ["add", "remove", "replace"]:
raise SCIMValidationError()
attr_path = AttrPath(f'{operation.path} eq ""', {})
if attr_path.first_path == ("members", None, None):
# FIXME: this can probably be de-duplicated
if operation.op == PatchOp.add:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.add(*User.objects.filter(query))
elif operation.op == PatchOp.remove:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.remove(*User.objects.filter(query))
return Response(self.group_to_scim(connection), status=200)
@atomic
def delete(self, request: Request, group_id: str, **kwargs) -> Response:
"""Delete group handler"""
@ -180,7 +140,7 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection.group.delete()
connection.delete()
return Response(status=204)

View File

@ -1,11 +1,11 @@
"""SCIM Meta views"""
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
class ResourceTypesView(SCIMView):
@ -138,7 +138,7 @@ class ResourceTypesView(SCIMView):
resource = [x for x in resource_types if x.get("id") == resource_type]
if resource:
return Response(resource[0])
raise SCIMNotFoundError("Resource not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -3,12 +3,12 @@
from json import loads
from django.conf import settings
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
with open(
settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
@ -44,7 +44,7 @@ class SchemaView(SCIMView):
schema = [x for x in schemas if x.get("id") == schema_uri]
if schema:
return Response(schema[0])
raise SCIMNotFoundError("Schema not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView):
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
"authenticationSchemes": auth_schemas,
# We only support patch for groups currently, so don't broadly advertise it.
# Implementations that require Group patch will use it regardless of this flag.
"patch": {"supported": False},
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
"filter": {

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydanticscim.user import Email, EmailKind, Name
from rest_framework.exceptions import ValidationError
@ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import User as SCIMUserModel
from authentik.sources.scim.models import SCIMSourceUser
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
class UsersView(SCIMObjectView):
@ -70,7 +69,7 @@ class UsersView(SCIMObjectView):
.first()
)
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
return Response(self.user_to_scim(connection))
connections = (
SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
@ -123,7 +122,7 @@ class UsersView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing user")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_user(None, request.data)
return Response(self.user_to_scim(connection), status=201)
@ -131,7 +130,7 @@ class UsersView(SCIMObjectView):
"""Update user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
self.update_user(connection, request.data)
return Response(self.user_to_scim(connection), status=200)
@ -140,7 +139,7 @@ class UsersView(SCIMObjectView):
"""Delete user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
connection.user.delete()
connection.delete()
return Response(status=204)

View File

@ -1,7 +1,6 @@
"""Validation stage challenge checking"""
from json import loads
from typing import TYPE_CHECKING
from urllib.parse import urlencode
from django.http import HttpRequest
@ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice
from authentik.stages.authenticator_sms.models import SMSDevice
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView
class DeviceChallenge(PassiveSerializer):
@ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer):
def get_challenge_for_device(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
request: HttpRequest, stage: AuthenticatorValidateStage, device: Device
) -> dict:
"""Generate challenge for a single device"""
if isinstance(device, WebAuthnDevice):
return get_webauthn_challenge(stage_view, stage, device)
return get_webauthn_challenge(request, stage, device)
if isinstance(device, EmailDevice):
return {"email": mask_email(device.email)}
# Code-based challenges have no hints
@ -67,30 +64,26 @@ def get_challenge_for_device(
def get_webauthn_challenge_without_user(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
request: HttpRequest, stage: AuthenticatorValidateStage
) -> dict:
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check
who the device belongs to."""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=[],
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
def get_webauthn_challenge(
stage_view: "AuthenticatorValidateStageView",
stage: AuthenticatorValidateStage,
device: WebAuthnDevice | None = None,
request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None
) -> dict:
"""Send the client a challenge that we'll check later"""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
allowed_credentials = []
@ -101,14 +94,12 @@ def get_webauthn_challenge(
allowed_credentials.append(user_device.descriptor)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=allowed_credentials,
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
@ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
"""Validate WebAuthn Challenge"""
request = stage_view.request
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
try:
credential = parse_authentication_credential_json(data)

View File

@ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={
"device_class": device_class,
"device_uid": device.pk,
"challenge": get_challenge_for_device(self, stage, device),
"challenge": get_challenge_for_device(self.request, stage, device),
"last_used": device.last_used,
}
)
@ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_class": DeviceClasses.WEBAUTHN,
"device_uid": -1,
"challenge": get_webauthn_challenge_without_user(
self,
self.request,
self.executor.current_stage,
),
"last_used": None,

View File

@ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice,
WebAuthnDeviceType,
)
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.user_login.models import UserLoginStage
@ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
device_classes=[DeviceClasses.WEBAUTHN],
webauthn_user_verification=UserVerification.PREFERRED,
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
del challenge["challenge"]
self.assertEqual(
challenge,
@ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
with self.assertRaises(ValidationError):
validate_challenge_webauthn(
{},
StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request),
self.user,
{}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user
)
def test_device_challenge_webauthn_restricted(self):
@ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge["allowCredentials"],
[
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
)
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(
challenge["rpId"],
"testserver",
)
self.assertEqual(
challenge["timeout"],
60000,
)
self.assertEqual(
challenge["userVerification"],
"preferred",
challenge,
{
"allowCredentials": [
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
def test_get_challenge_userless(self):
@ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
challenge = get_webauthn_challenge_without_user(request, stage)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge,
{
"allowCredentials": [],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
challenge = get_webauthn_challenge_without_user(stage_view, stage)
self.assertEqual(challenge["allowCredentials"], [])
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(challenge["rpId"], "testserver")
self.assertEqual(challenge["timeout"], 60000)
self.assertEqual(challenge["userVerification"], "preferred")
def test_validate_challenge_unrestricted(self):
"""Test webauthn authentication (unrestricted webauthn device)"""
@ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.WEBAUTHN],
)
plan = FlowPlan(flow.pk.hex)
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request = get_request("/")
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
request.session.save()
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request.META["SERVER_NAME"] = "localhost"
request.META["SERVER_PORT"] = "9000"

View File

@ -25,7 +25,6 @@ class AuthenticatorWebAuthnStageSerializer(StageSerializer):
"resident_key_requirement",
"device_type_restrictions",
"device_type_restrictions_obj",
"max_attempts",
]

View File

@ -1,21 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-13 22:41
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
(
"authentik_stages_authenticator_webauthn",
"0012_webauthndevice_created_webauthndevice_last_updated_and_more",
),
]
operations = [
migrations.AddField(
model_name="authenticatorwebauthnstage",
name="max_attempts",
field=models.PositiveIntegerField(default=0),
),
]

View File

@ -84,8 +84,6 @@ class AuthenticatorWebAuthnStage(ConfigurableStage, FriendlyNamedStage, Stage):
device_type_restrictions = models.ManyToManyField("WebAuthnDeviceType", blank=True)
max_attempts = models.PositiveIntegerField(default=0)
@property
def serializer(self) -> type[BaseSerializer]:
from authentik.stages.authenticator_webauthn.api.stages import (

View File

@ -5,13 +5,12 @@ from uuid import UUID
from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict
from django.utils.translation import gettext as __
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from webauthn import options_to_json
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from webauthn.helpers.exceptions import WebAuthnException
from webauthn.helpers.exceptions import InvalidRegistrationResponse
from webauthn.helpers.structs import (
AttestationConveyancePreference,
AuthenticatorAttachment,
@ -42,8 +41,7 @@ from authentik.stages.authenticator_webauthn.models import (
)
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
PLAN_CONTEXT_WEBAUTHN_CHALLENGE = "goauthentik.io/stages/authenticator_webauthn/challenge"
PLAN_CONTEXT_WEBAUTHN_ATTEMPT = "goauthentik.io/stages/authenticator_webauthn/attempt"
SESSION_KEY_WEBAUTHN_CHALLENGE = "authentik/stages/authenticator_webauthn/challenge"
class AuthenticatorWebAuthnChallenge(WithUserInfoChallenge):
@ -64,7 +62,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
def validate_response(self, response: dict) -> dict:
"""Validate webauthn challenge response"""
challenge = self.stage.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]
challenge = self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
try:
registration: VerifiedRegistration = verify_registration_response(
@ -73,7 +71,7 @@ class AuthenticatorWebAuthnChallengeResponse(ChallengeResponse):
expected_rp_id=get_rp_id(self.request),
expected_origin=get_origin(self.request),
)
except WebAuthnException as exc:
except InvalidRegistrationResponse as exc:
self.stage.logger.warning("registration failed", exc=exc)
raise ValidationError(f"Registration failed. Error: {exc}") from None
@ -116,10 +114,9 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response_class = AuthenticatorWebAuthnChallengeResponse
def get_challenge(self, *args, **kwargs) -> Challenge:
# clear session variables prior to starting a new registration
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
# clear flow variables prior to starting a new registration
self.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
user = self.get_pending_user()
# library accepts none so we store null in the database, but if there is a value
@ -142,7 +139,8 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
attestation=AttestationConveyancePreference.DIRECT,
)
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = registration_options.challenge
self.request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = registration_options.challenge
self.request.session.save()
return AuthenticatorWebAuthnChallenge(
data={
"registration": loads(options_to_json(registration_options)),
@ -155,24 +153,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
response.user = self.get_pending_user()
return response
def challenge_invalid(self, response):
stage: AuthenticatorWebAuthnStage = self.executor.current_stage
self.executor.plan.context.setdefault(PLAN_CONTEXT_WEBAUTHN_ATTEMPT, 0)
self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] += 1
if (
stage.max_attempts > 0
and self.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_ATTEMPT] >= stage.max_attempts
):
return self.executor.stage_invalid(
__(
"Exceeded maximum attempts. "
"Contact your {brand} administrator for help.".format(
brand=self.request.brand.branding_title
)
)
)
return super().challenge_invalid(response)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# Webauthn Challenge has already been validated
webauthn_credential: VerifiedRegistration = response.validated_data["response"]
@ -199,3 +179,6 @@ class AuthenticatorWebAuthnStageView(ChallengeStageView):
else:
return self.executor.stage_invalid("Device with Credential ID already exists.")
return self.executor.stage_ok()
def cleanup(self):
self.request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)

View File

@ -18,7 +18,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice,
WebAuthnDeviceType,
)
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
@ -57,9 +57,6 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
)
plan: FlowPlan = self.client.session[SESSION_KEY_PLAN]
self.assertEqual(response.status_code, 200)
session = self.client.session
self.assertStageResponse(
@ -73,7 +70,7 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"name": self.user.username,
"displayName": self.user.name,
},
"challenge": bytes_to_base64url(plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE]),
"challenge": bytes_to_base64url(session[SESSION_KEY_WEBAUTHN_CHALLENGE]),
"pubKeyCredParams": [
{"type": "public-key", "alg": -7},
{"type": "public-key", "alg": -8},
@ -100,11 +97,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
"""Test registration"""
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -149,11 +146,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -212,11 +209,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -262,11 +259,11 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session.save()
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
@ -301,109 +298,3 @@ class TestAuthenticatorWebAuthnStage(FlowTestCase):
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
self.assertTrue(WebAuthnDevice.objects.filter(user=self.user).exists())
def test_register_max_retries(self):
"""Test registration (exceeding max retries)"""
self.stage.max_attempts = 2
self.stage.save()
plan = FlowPlan(flow_pk=self.flow.pk.hex, bindings=[self.binding], markers=[StageMarker()])
plan.context[PLAN_CONTEXT_PENDING_USER] = self.user
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = b64decode(
b"03Xodi54gKsfnP5I9VFfhaGXVVE2NUyZpBBXns/JI+x6V9RY2Tw2QmxRJkhh7174EkRazUntIwjMVY9bFG60Lw=="
)
session = self.client.session
session[SESSION_KEY_PLAN] = plan
session.save()
# first failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-authenticator-webauthn",
response_errors={
"response": [
{
"string": (
"Registration failed. Error: Unable to decode "
"client_data_json bytes as JSON"
),
"code": "invalid",
}
]
},
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())
# Second failed request
response = self.client.post(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}),
data={
"component": "ak-stage-authenticator-webauthn",
"response": {
"id": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"rawId": "kqnmrVLnDG-OwsSNHkihYZaNz5s",
"type": "public-key",
"registrationClientExtensions": "{}",
"response": {
"clientDataJSON": (
"eyJ0eXBlIjoid2ViYXV0aG4uY3JlYXRlIiwiY2hhbGxlbmd"
"lIjoiMDNYb2RpNTRnS3NmblA1STlWRmZoYUdYVlZFMk5VeV"
"pwQkJYbnNfSkkteDZWOVJZMlR3MlFteFJKa2hoNzE3NEVrU"
"mF6VW50SXdqTVZZOWJGRzYwTHciLCJvcmlnaW4iOiJodHRw"
"Oi8vbG9jYWxob3N0OjkwMDAiLCJjcm9zc09yaWdpbiI6ZmF"
),
"attestationObject": (
"o2NmbXRkbm9uZWdhdHRTdG10oGhhdXRoRGF0YViYSZYN5Yg"
"OjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2NdAAAAAPv8MA"
"cVTk7MjAtuAgVX170AFJKp5q1S5wxvjsLEjR5IoWGWjc-bp"
"QECAyYgASFYIKtcZHPumH37XHs0IM1v3pUBRIqHVV_SE-Lq"
"2zpJAOVXIlgg74Fg_WdB0kuLYqCKbxogkEPaVtR_iR3IyQFIJAXBzds"
),
},
},
},
SERVER_NAME="localhost",
SERVER_PORT="9000",
)
self.assertEqual(response.status_code, 200)
self.assertStageResponse(
response,
flow=self.flow,
component="ak-stage-access-denied",
error_message=(
"Exceeded maximum attempts. Contact your authentik administrator for help."
),
)
self.assertFalse(WebAuthnDevice.objects.filter(user=self.user).exists())

View File

@ -101,9 +101,9 @@ class BoundSessionMiddleware(SessionMiddleware):
SESSION_KEY_BINDING_GEO, GeoIPBinding.NO_BINDING
)
if configured_binding_net != NetworkBinding.NO_BINDING:
BoundSessionMiddleware.recheck_session_net(configured_binding_net, last_ip, new_ip)
self.recheck_session_net(configured_binding_net, last_ip, new_ip)
if configured_binding_geo != GeoIPBinding.NO_BINDING:
BoundSessionMiddleware.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
self.recheck_session_geo(configured_binding_geo, last_ip, new_ip)
# If we got to this point without any error being raised, we need to
# update the last saved IP to the current one
if SESSION_KEY_BINDING_NET in request.session or SESSION_KEY_BINDING_GEO in request.session:
@ -111,8 +111,7 @@ class BoundSessionMiddleware(SessionMiddleware):
# (== basically requires the user to be logged in)
request.session[request.session.model.Keys.LAST_IP] = new_ip
@staticmethod
def recheck_session_net(binding: NetworkBinding, last_ip: str, new_ip: str):
def recheck_session_net(self, binding: NetworkBinding, last_ip: str, new_ip: str):
"""Check network/ASN binding"""
last_asn = ASN_CONTEXT_PROCESSOR.asn(last_ip)
new_asn = ASN_CONTEXT_PROCESSOR.asn(new_ip)
@ -159,8 +158,7 @@ class BoundSessionMiddleware(SessionMiddleware):
new_ip,
)
@staticmethod
def recheck_session_geo(binding: GeoIPBinding, last_ip: str, new_ip: str):
def recheck_session_geo(self, binding: GeoIPBinding, last_ip: str, new_ip: str):
"""Check GeoIP binding"""
last_geo = GEOIP_CONTEXT_PROCESSOR.city(last_ip)
new_geo = GEOIP_CONTEXT_PROCESSOR.city(new_ip)
@ -181,8 +179,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.continent != new_geo.continent:
raise SessionBindingBroken(
"geoip.continent",
last_geo.continent.to_dict(),
new_geo.continent.to_dict(),
last_geo.continent,
new_geo.continent,
last_ip,
new_ip,
)
@ -194,8 +192,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.country != new_geo.country:
raise SessionBindingBroken(
"geoip.country",
last_geo.country.to_dict(),
new_geo.country.to_dict(),
last_geo.country,
new_geo.country,
last_ip,
new_ip,
)
@ -204,8 +202,8 @@ class BoundSessionMiddleware(SessionMiddleware):
if last_geo.city != new_geo.city:
raise SessionBindingBroken(
"geoip.city",
last_geo.city.to_dict(),
new_geo.city.to_dict(),
last_geo.city,
new_geo.city,
last_ip,
new_ip,
)

View File

@ -3,7 +3,6 @@
from time import sleep
from unittest.mock import patch
from django.http import HttpRequest
from django.urls import reverse
from django.utils.timezone import now
@ -18,12 +17,7 @@ from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.lib.utils.time import timedelta_from_string
from authentik.root.middleware import ClientIPMiddleware
from authentik.stages.user_login.middleware import (
BoundSessionMiddleware,
SessionBindingBroken,
logout_extra,
)
from authentik.stages.user_login.models import GeoIPBinding, NetworkBinding, UserLoginStage
from authentik.stages.user_login.models import UserLoginStage
class TestUserLoginStage(FlowTestCase):
@ -198,52 +192,3 @@ class TestUserLoginStage(FlowTestCase):
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
response = self.client.get(reverse("authentik_api:application-list"))
self.assertEqual(response.status_code, 403)
def test_binding_net_break_log(self):
"""Test logout_extra with exception"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-ASN-Test.json
for args, expect in [
[[NetworkBinding.BIND_ASN, "8.8.8.8", "8.8.8.8"], ["network.missing"]],
[[NetworkBinding.BIND_ASN, "1.0.0.1", "1.128.0.1"], ["network.asn"]],
[
[NetworkBinding.BIND_ASN_NETWORK, "12.81.96.1", "12.81.128.1"],
["network.asn_network"],
],
[[NetworkBinding.BIND_ASN_NETWORK_IP, "1.0.0.1", "1.0.0.2"], ["network.ip"]],
]:
with self.subTest(args[0]):
with self.assertRaises(SessionBindingBroken) as cm:
BoundSessionMiddleware.recheck_session_net(*args)
self.assertEqual(cm.exception.reason, expect[0])
# Ensure the request can be logged without throwing errors
self.client.force_login(self.user)
request = HttpRequest()
request.session = self.client.session
request.user = self.user
logout_extra(request, cm.exception)
def test_binding_geo_break_log(self):
"""Test logout_extra with exception"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
for args, expect in [
[[GeoIPBinding.BIND_CONTINENT, "8.8.8.8", "8.8.8.8"], ["geoip.missing"]],
[[GeoIPBinding.BIND_CONTINENT, "2.125.160.216", "67.43.156.1"], ["geoip.continent"]],
[
[GeoIPBinding.BIND_CONTINENT_COUNTRY, "81.2.69.142", "89.160.20.112"],
["geoip.country"],
],
[
[GeoIPBinding.BIND_CONTINENT_COUNTRY_CITY, "2.125.160.216", "81.2.69.142"],
["geoip.city"],
],
]:
with self.subTest(args[0]):
with self.assertRaises(SessionBindingBroken) as cm:
BoundSessionMiddleware.recheck_session_geo(*args)
self.assertEqual(cm.exception.reason, expect[0])
# Ensure the request can be logged without throwing errors
self.client.force_login(self.user)
request = HttpRequest()
request.session = self.client.session
request.user = self.user
logout_extra(request, cm.exception)

View File

@ -16,7 +16,6 @@ def check_embedded_outpost_disabled(app_configs, **kwargs):
"Embedded outpost must be disabled when tenants API is enabled.",
hint="Disable embedded outpost by setting outposts.disable_embedded_outpost to "
"True, or disable the tenants API by setting tenants.enabled to False",
id="ak.tenants.E001",
)
]
return []

View File

@ -6628,16 +6628,11 @@
"title": "Severity",
"description": "Controls which severity level the created notifications will have."
},
"destination_group": {
"group": {
"type": "string",
"format": "uuid",
"title": "Destination group",
"title": "Group",
"description": "Define which group of users this notification should be sent and shown to. If left empty, Notification won't ben sent."
},
"destination_event_user": {
"type": "boolean",
"title": "Destination event user",
"description": "When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both."
}
},
"required": []
@ -7345,7 +7340,6 @@
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.ssf",
"authentik.enterprise.search",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source",
@ -9239,6 +9233,18 @@
"type": "string",
"title": "Default relay state",
"description": "Default relay_state value for IDP-initiated logins"
},
"default_name_id_policy": {
"type": "string",
"enum": [
"urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress",
"urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
"urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName",
"urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName",
"urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
"urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
],
"title": "Default name id policy"
}
},
"required": []
@ -11661,7 +11667,8 @@
"urn:oasis:names:tc:SAML:2.0:nameid-format:persistent",
"urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName",
"urn:oasis:names:tc:SAML:2.0:nameid-format:WindowsDomainQualifiedName",
"urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
"urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
"urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
],
"title": "Name id policy",
"description": "NameID Policy sent to the IdP. Can be unset, in which case no Policy is sent."
@ -13310,12 +13317,6 @@
"format": "uuid"
},
"title": "Device type restrictions"
},
"max_attempts": {
"type": "integer",
"minimum": 0,
"maximum": 2147483647,
"title": "Max attempts"
}
},
"required": []

4
go.mod
View File

@ -6,7 +6,7 @@ require (
beryju.io/ldap v0.1.0
github.com/avast/retry-go/v4 v4.6.1
github.com/coreos/go-oidc/v3 v3.14.1
github.com/getsentry/sentry-go v0.34.0
github.com/getsentry/sentry-go v0.33.0
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
github.com/go-ldap/ldap/v3 v3.4.11
github.com/go-openapi/runtime v0.28.0
@ -29,7 +29,7 @@ require (
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
github.com/wwt/guac v1.3.2
goauthentik.io/api/v3 v3.2025062.4
goauthentik.io/api/v3 v3.2025062.1
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0

8
go.sum
View File

@ -71,8 +71,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/getsentry/sentry-go v0.34.0 h1:1FCHBVp8TfSc8L10zqSwXUZNiOSF+10qw4czjarTiY4=
github.com/getsentry/sentry-go v0.34.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/getsentry/sentry-go v0.33.0 h1:YWyDii0KGVov3xOaamOnF0mjOrqSjBqwv48UEzn7QFg=
github.com/getsentry/sentry-go v0.33.0/go.mod h1:C55omcY9ChRQIUcVcGcs+Zdy4ZpQGvNJ7JYHIoSWOtE=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
@ -298,8 +298,8 @@ go.opentelemetry.io/otel/trace v1.24.0 h1:CsKnnL4dUAr/0llH9FKuc698G04IrpWV0MQA/Y
go.opentelemetry.io/otel/trace v1.24.0/go.mod h1:HPc3Xr/cOApsBI154IU0OI0HJexz+aw5uPdbs3UCjNU=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
goauthentik.io/api/v3 v3.2025062.4 h1:HuyL12kKserXT2w+wCDUYNRSeyCCGX81wU9SRCPuxDo=
goauthentik.io/api/v3 v3.2025062.4/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
goauthentik.io/api/v3 v3.2025062.1 h1:spvILDpDDWJNO3pM6QGqmryx6NvSchr1E8H60J/XUCA=
goauthentik.io/api/v3 v3.2025062.1/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=

View File

@ -9,7 +9,7 @@
"version": "0.0.0",
"license": "MIT",
"devDependencies": {
"aws-cdk": "^2.1019.1",
"aws-cdk": "^2.1018.1",
"cross-env": "^7.0.3"
},
"engines": {
@ -17,9 +17,9 @@
}
},
"node_modules/aws-cdk": {
"version": "2.1019.1",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1019.1.tgz",
"integrity": "sha512-G2jxKuTsYTrYZX80CDApCrKcZ+AuFxxd+b0dkb0KEkfUsela7RqrDGLm5wOzSCIc3iH6GocR8JDVZuJ+0nNuKg==",
"version": "2.1018.1",
"resolved": "https://registry.npmjs.org/aws-cdk/-/aws-cdk-2.1018.1.tgz",
"integrity": "sha512-kFPRox5kSm+ktJ451o0ng9rD+60p5Kt1CZIWw8kXnvqbsxN2xv6qbmyWSXw7sGVXVwqrRKVj+71/JeDr+LMAZw==",
"dev": true,
"license": "Apache-2.0",
"bin": {

View File

@ -10,7 +10,7 @@
"node": ">=20"
},
"devDependencies": {
"aws-cdk": "^2.1019.1",
"aws-cdk": "^2.1018.1",
"cross-env": "^7.0.3"
}
}

View File

@ -7,6 +7,8 @@ from pathlib import Path
from tempfile import gettempdir
from typing import TYPE_CHECKING
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from prometheus_client.values import MultiProcessValue
from authentik import get_full_version
@ -16,7 +18,6 @@ from authentik.lib.logging import get_logger_config
from authentik.lib.utils.http import get_http_session
from authentik.lib.utils.reflection import get_env
from authentik.root.install_id import get_install_id_raw
from authentik.root.setup import setup
from lifecycle.migrate import run_migrations
from lifecycle.wait_for_db import wait_for_db
from lifecycle.worker import DjangoUvicornWorker
@ -27,7 +28,10 @@ if TYPE_CHECKING:
from authentik.root.asgi import AuthentikAsgi
setup()
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
wait_for_db()

View File

@ -10,7 +10,7 @@ from typing import Any
from psycopg import Connection, Cursor, connect
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG, django_db_config
from authentik.lib.config import CONFIG
LOGGER = get_logger()
ADV_LOCK_UID = 1000
@ -115,13 +115,9 @@ def run_migrations():
execute_from_command_line(["", "migrate_schemas"])
if CONFIG.get_bool("tenants.enabled", False):
execute_from_command_line(["", "migrate_schemas", "--schema", "template", "--tenant"])
# Run django system checks for all databases
check_args = ["", "check"]
for label in django_db_config(CONFIG).keys():
check_args.append(f"--database={label}")
if not CONFIG.get_bool("debug"):
check_args.append("--deploy")
execute_from_command_line(check_args)
execute_from_command_line(
["", "check"] + ([] if CONFIG.get_bool("debug") else ["--deploy"])
)
finally:
release_lock(curr)
curr.close()

View File

@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2025-06-19 00:10+0000\n"
"POT-Creation-Date: 2025-06-04 00:12+0000\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@ -763,12 +763,6 @@ msgid ""
"If left empty, Notification won't ben sent."
msgstr ""
#: authentik/events/models.py
msgid ""
"When enabled, notification will be sent to user the user that triggered the "
"event.When destination_group is configured, notification is sent to both."
msgstr ""
#: authentik/events/models.py
msgid "Notification Rule"
msgstr ""

View File

@ -1,18 +1,35 @@
#!/usr/bin/env python
"""Django manage.py"""
import os
import sys
import warnings
from authentik.lib.config import CONFIG
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
from authentik.root.setup import setup
from lifecycle.migrate import run_migrations
from lifecycle.wait_for_db import wait_for_db
setup()
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
"ignore",
"defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
"ignore",
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
if __name__ == "__main__":
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
wait_for_db()
if (
len(sys.argv) > 1

View File

@ -1,10 +0,0 @@
# Prettier Ignorefile
## Static Files
**/LICENSE
## Build asset directories
coverage
dist
out
.docusaurus

View File

@ -4,5 +4,3 @@
export * from "./lib/theme.js";
export * from "./lib/common.js";
export * from "./lib/routing.js";
export * from "./lib/navbar.js";

Some files were not shown because too many files have changed in this diff Show More