Compare commits

..

1 Commits

Author SHA1 Message Date
b0e6558a4f current version
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2025-06-13 16:09:00 +02:00
408 changed files with 13849 additions and 13185 deletions

View File

@ -1,5 +1,5 @@
[bumpversion]
current_version = 2025.6.3
current_version = 2025.6.1
tag = True
commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
@ -21,8 +21,6 @@ optional_value = final
[bumpversion:file:package.json]
[bumpversion:file:package-lock.json]
[bumpversion:file:docker-compose.yml]
[bumpversion:file:schema.yml]
@ -33,4 +31,6 @@ optional_value = final
[bumpversion:file:internal/constants/constants.go]
[bumpversion:file:web/src/common/constants.ts]
[bumpversion:file:lifecycle/aws/template.yaml]

View File

@ -7,9 +7,6 @@ charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
[*.toml]
indent_size = 2
[*.html]
indent_size = 2

View File

@ -38,8 +38,6 @@ jobs:
# Needed for attestation
id-token: write
attestations: write
# Needed for checkout
contents: read
steps:
- uses: actions/checkout@v4
- uses: docker/setup-qemu-action@v3.6.0

View File

@ -9,15 +9,14 @@ on:
jobs:
test-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
version:
- docs
- version-2025-4
- version-2025-2
- version-2024-12
steps:
- uses: actions/checkout@v4
- run: |

View File

@ -202,7 +202,7 @@ jobs:
uses: actions/cache@v4
with:
path: web/dist
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
key: ${{ runner.os }}-web-${{ hashFiles('web/package-lock.json', 'web/src/**', 'web/packages/sfe/src/**') }}-b
- name: prepare web ui
if: steps.cache-web.outputs.cache-hit != 'true'
working-directory: web
@ -247,13 +247,11 @@ jobs:
# Needed for attestation
id-token: write
attestations: write
# Needed for checkout
contents: read
needs: ci-core-mark
uses: ./.github/workflows/_reusable-docker-build.yaml
secrets: inherit
with:
image_name: ${{ github.repository == 'goauthentik/authentik-internal' && 'ghcr.io/goauthentik/internal-server' || 'ghcr.io/goauthentik/dev-server' }}
image_name: ghcr.io/goauthentik/dev-server
release: false
pr-comment:
needs:

View File

@ -59,7 +59,6 @@ jobs:
with:
jobs: ${{ toJSON(needs) }}
build-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
timeout-minutes: 120
needs:
- ci-outpost-mark

View File

@ -41,29 +41,7 @@ jobs:
- name: test
working-directory: website/
run: npm test
build:
runs-on: ubuntu-latest
name: ${{ matrix.job }}
strategy:
fail-fast: false
matrix:
job:
- build
- build:integrations
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version-file: website/package.json
cache: "npm"
cache-dependency-path: website/package-lock.json
- working-directory: website/
run: npm ci
- name: build
working-directory: website/
run: npm run ${{ matrix.job }}
build-container:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest
permissions:
# Needed to upload container images to ghcr.io
@ -116,11 +94,9 @@ jobs:
needs:
- lint
- test
- build
- build-container
runs-on: ubuntu-latest
steps:
- uses: re-actors/alls-green@release/v1
with:
jobs: ${{ toJSON(needs) }}
allowed-skips: ${{ github.repository == 'goauthentik/authentik-internal' && 'build-container' || '[]' }}

View File

@ -2,7 +2,7 @@ name: "CodeQL"
on:
push:
branches: [main, next, version*]
branches: [main, "*", next, version*]
pull_request:
branches: [main]
schedule:

View File

@ -1,21 +0,0 @@
name: "authentik-repo-mirror-cleanup"
on:
workflow_dispatch:
jobs:
to_internal:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- if: ${{ env.MIRROR_KEY != '' }}
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
with:
target_repo_url: git@github.com:goauthentik/authentik-internal.git
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
args: --tags --force --prune
env:
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}

View File

@ -11,10 +11,11 @@ jobs:
with:
fetch-depth: 0
- if: ${{ env.MIRROR_KEY != '' }}
uses: BeryJu/repository-mirroring-action@5cf300935bc2e068f73ea69bcc411a8a997208eb
uses: pixta-dev/repository-mirroring-action@v1
with:
target_repo_url: git@github.com:goauthentik/authentik-internal.git
ssh_private_key: ${{ secrets.GH_MIRROR_KEY }}
args: --tags --force
target_repo_url:
git@github.com:goauthentik/authentik-internal.git
ssh_private_key:
${{ secrets.GH_MIRROR_KEY }}
env:
MIRROR_KEY: ${{ secrets.GH_MIRROR_KEY }}

View File

@ -16,7 +16,6 @@ env:
jobs:
compile:
if: ${{ github.repository != 'goauthentik/authentik-internal' }}
runs-on: ubuntu-latest
steps:
- id: generate_token

View File

@ -6,15 +6,13 @@
"!Context scalar",
"!Enumerate sequence",
"!Env scalar",
"!Env sequence",
"!Find sequence",
"!Format sequence",
"!If sequence",
"!Index scalar",
"!KeyOf scalar",
"!Value scalar",
"!AtIndex scalar",
"!ParseJSON scalar"
"!AtIndex scalar"
],
"typescript.preferences.importModuleSpecifier": "non-relative",
"typescript.preferences.importModuleSpecifierEnding": "index",

View File

@ -75,9 +75,9 @@ RUN --mount=type=secret,id=GEOIPUPDATE_ACCOUNT_ID \
/bin/sh -c "GEOIPUPDATE_LICENSE_KEY_FILE=/run/secrets/GEOIPUPDATE_LICENSE_KEY /usr/bin/entry.sh || echo 'Failed to get GeoIP database, disabling'; exit 0"
# Stage 4: Download uv
FROM ghcr.io/astral-sh/uv:0.7.17 AS uv
FROM ghcr.io/astral-sh/uv:0.7.13 AS uv
# Stage 5: Base python image
FROM ghcr.io/goauthentik/fips-python:3.13.5-slim-bookworm-fips AS python-base
FROM ghcr.io/goauthentik/fips-python:3.13.4-slim-bookworm-fips AS python-base
ENV VENV_PATH="/ak-root/.venv" \
PATH="/lifecycle:/ak-root/.venv/bin:$PATH" \

View File

@ -86,10 +86,6 @@ dev-create-db:
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
update-test-mmdb: ## Update test GeoIP and ASN Databases
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-ASN-Test.mmdb -o ${PWD}/tests/GeoLite2-ASN-Test.mmdb
curl -L https://raw.githubusercontent.com/maxmind/MaxMind-DB/refs/heads/main/test-data/GeoLite2-City-Test.mmdb -o ${PWD}/tests/GeoLite2-City-Test.mmdb
#########################
## API Schema
#########################
@ -98,7 +94,7 @@ gen-build: ## Extract the schema from the database
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
uv run ak make_blueprint_schema --file blueprints/schema.json
uv run ak make_blueprint_schema > blueprints/schema.json
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
@ -150,9 +146,9 @@ gen-client-ts: gen-clean-ts ## Build and install the authentik API for Typescri
--additional-properties=npmVersion=${NPM_VERSION} \
--git-repo-id authentik \
--git-user-id goauthentik
cd ${PWD}/${GEN_API_TS} && npm link
cd ${PWD}/web && npm link @goauthentik/api
mkdir -p web/node_modules/@goauthentik/api
cd ${PWD}/${GEN_API_TS} && npm i
\cp -rf ${PWD}/${GEN_API_TS}/* web/node_modules/@goauthentik/api
gen-client-py: gen-clean-py ## Build and install the authentik API for Python
docker run \

View File

@ -2,7 +2,7 @@
from os import environ
__version__ = "2025.6.3"
__version__ = "2025.6.1"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View File

@ -72,33 +72,20 @@ class Command(BaseCommand):
"additionalProperties": True,
},
"entries": {
"anyOf": [
{
"type": "array",
"items": {"$ref": "#/$defs/blueprint_entry"},
},
{
"type": "object",
"additionalProperties": {
"type": "array",
"items": {"$ref": "#/$defs/blueprint_entry"},
},
},
],
"type": "array",
"items": {
"oneOf": [],
},
},
},
"$defs": {"blueprint_entry": {"oneOf": []}},
"$defs": {},
}
def add_arguments(self, parser):
parser.add_argument("--file", type=str)
@no_translations
def handle(self, *args, file: str, **options):
def handle(self, *args, **options):
"""Generate JSON Schema for blueprints"""
self.build()
with open(file, "w") as _schema:
_schema.write(dumps(self.schema, indent=4, default=Command.json_default))
self.stdout.write(dumps(self.schema, indent=4, default=Command.json_default))
@staticmethod
def json_default(value: Any) -> Any:
@ -125,7 +112,7 @@ class Command(BaseCommand):
}
)
model_path = f"{model._meta.app_label}.{model._meta.model_name}"
self.schema["$defs"]["blueprint_entry"]["oneOf"].append(
self.schema["properties"]["entries"]["items"]["oneOf"].append(
self.template_entry(model_path, model, serializer)
)

View File

@ -1,11 +1,10 @@
version: 1
entries:
foo:
- identifiers:
name: "%(id)s"
slug: "%(id)s"
model: authentik_flows.flow
state: present
attrs:
designation: stage_configuration
title: foo
- identifiers:
name: "%(id)s"
slug: "%(id)s"
model: authentik_flows.flow
state: present
attrs:
designation: stage_configuration
title: foo

View File

@ -37,7 +37,6 @@ entries:
- attrs:
attributes:
env_null: !Env [bar-baz, null]
json_parse: !ParseJSON '{"foo": "bar"}'
policy_pk1:
!Format [
"%s-%s",

View File

@ -35,6 +35,6 @@ def blueprint_tester(file_name: Path) -> Callable:
for blueprint_file in Path("blueprints/").glob("**/*.yaml"):
if "local" in str(blueprint_file) or "testing" in str(blueprint_file):
if "local" in str(blueprint_file):
continue
setattr(TestPackaged, f"test_blueprint_{blueprint_file}", blueprint_tester(blueprint_file))

View File

@ -5,6 +5,7 @@ from collections.abc import Callable
from django.apps import apps
from django.test import TestCase
from authentik.blueprints.v1.importer import is_model_allowed
from authentik.lib.models import SerializerModel
from authentik.providers.oauth2.models import RefreshToken
@ -21,13 +22,10 @@ def serializer_tester_factory(test_model: type[SerializerModel]) -> Callable:
return
model_class = test_model()
self.assertTrue(isinstance(model_class, SerializerModel))
# Models that have subclasses don't have to have a serializer
if len(test_model.__subclasses__()) > 0:
return
self.assertIsNotNone(model_class.serializer)
if model_class.serializer.Meta().model == RefreshToken:
return
self.assertTrue(issubclass(test_model, model_class.serializer.Meta().model))
self.assertEqual(model_class.serializer.Meta().model, test_model)
return tester
@ -36,6 +34,6 @@ for app in apps.get_app_configs():
if not app.label.startswith("authentik"):
continue
for model in app.get_models():
if not issubclass(model, SerializerModel):
if not is_model_allowed(model):
continue
setattr(TestModels, f"test_{app.label}_{model.__name__}", serializer_tester_factory(model))

View File

@ -215,7 +215,6 @@ class TestBlueprintsV1(TransactionTestCase):
},
"nested_context": "context-nested-value",
"env_null": None,
"json_parse": {"foo": "bar"},
"at_index_sequence": "foo",
"at_index_sequence_default": "non existent",
"at_index_mapping": 2,

View File

@ -6,7 +6,6 @@ from copy import copy
from dataclasses import asdict, dataclass, field, is_dataclass
from enum import Enum
from functools import reduce
from json import JSONDecodeError, loads
from operator import ixor
from os import getenv
from typing import Any, Literal, Union
@ -192,18 +191,11 @@ class Blueprint:
"""Dataclass used for a full export"""
version: int = field(default=1)
entries: list[BlueprintEntry] | dict[str, list[BlueprintEntry]] = field(default_factory=list)
entries: list[BlueprintEntry] = field(default_factory=list)
context: dict = field(default_factory=dict)
metadata: BlueprintMetadata | None = field(default=None)
def iter_entries(self) -> Iterable[BlueprintEntry]:
if isinstance(self.entries, dict):
for _section, entries in self.entries.items():
yield from entries
else:
yield from self.entries
class YAMLTag:
"""Base class for all YAML Tags"""
@ -234,7 +226,7 @@ class KeyOf(YAMLTag):
self.id_from = node.value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
for _entry in blueprint.iter_entries():
for _entry in blueprint.entries:
if _entry.id == self.id_from and _entry._state.instance:
# Special handling for PolicyBindingModels, as they'll have a different PK
# which is used when creating policy bindings
@ -292,22 +284,6 @@ class Context(YAMLTag):
return value
class ParseJSON(YAMLTag):
"""Parse JSON from context/env/etc value"""
raw: str
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
super().__init__()
self.raw = node.value
def resolve(self, entry: BlueprintEntry, blueprint: Blueprint) -> Any:
try:
return loads(self.raw)
except JSONDecodeError as exc:
raise EntryInvalidError.from_entry(exc, entry) from exc
class Format(YAMLTag):
"""Format a string"""
@ -683,7 +659,6 @@ class BlueprintLoader(SafeLoader):
self.add_constructor("!Value", Value)
self.add_constructor("!Index", Index)
self.add_constructor("!AtIndex", AtIndex)
self.add_constructor("!ParseJSON", ParseJSON)
class EntryInvalidError(SentryIgnoredException):

View File

@ -384,7 +384,7 @@ class Importer:
def _apply_models(self, raise_errors=False) -> bool:
"""Apply (create/update) models yaml"""
self.__pk_map = {}
for entry in self._import.iter_entries():
for entry in self._import.entries:
model_app_label, model_name = entry.get_model(self._import).split(".")
try:
model: type[SerializerModel] = registry.get_model(model_app_label, model_name)

View File

@ -1,6 +1,8 @@
"""Authenticator Devices API Views"""
from drf_spectacular.utils import extend_schema
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from guardian.shortcuts import get_objects_for_user
from rest_framework.fields import (
BooleanField,
@ -13,7 +15,6 @@ from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from authentik.core.api.users import ParamUserSerializer
from authentik.core.api.utils import MetaNameSerializer
from authentik.enterprise.stages.authenticator_endpoint_gdtc.models import EndpointDevice
from authentik.stages.authenticator import device_classes, devices_for_user
@ -22,7 +23,7 @@ from authentik.stages.authenticator_webauthn.models import WebAuthnDevice
class DeviceSerializer(MetaNameSerializer):
"""Serializer for authenticator devices"""
"""Serializer for Duo authenticator devices"""
pk = CharField()
name = CharField()
@ -32,27 +33,22 @@ class DeviceSerializer(MetaNameSerializer):
last_updated = DateTimeField(read_only=True)
last_used = DateTimeField(read_only=True, allow_null=True)
extra_description = SerializerMethodField()
external_id = SerializerMethodField()
def get_type(self, instance: Device) -> str:
"""Get type of device"""
return instance._meta.label
def get_extra_description(self, instance: Device) -> str | None:
def get_extra_description(self, instance: Device) -> str:
"""Get extra description"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.description if instance.device_type else None
return (
instance.device_type.description
if instance.device_type
else _("Extra description not available")
)
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
def get_external_id(self, instance: Device) -> str | None:
"""Get external Device ID"""
if isinstance(instance, WebAuthnDevice):
return instance.device_type.aaguid if instance.device_type else None
if isinstance(instance, EndpointDevice):
return instance.data.get("deviceSignals", {}).get("deviceModel")
return None
return ""
class DeviceViewSet(ViewSet):
@ -61,6 +57,7 @@ class DeviceViewSet(ViewSet):
serializer_class = DeviceSerializer
permission_classes = [IsAuthenticated]
@extend_schema(responses={200: DeviceSerializer(many=True)})
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
devices = devices_for_user(request.user)
@ -82,11 +79,18 @@ class AdminDeviceViewSet(ViewSet):
yield from device_set
@extend_schema(
parameters=[ParamUserSerializer],
parameters=[
OpenApiParameter(
name="user",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
)
],
responses={200: DeviceSerializer(many=True)},
)
def list(self, request: Request) -> Response:
"""Get all devices for current user"""
args = ParamUserSerializer(data=request.query_params)
args.is_valid(raise_exception=True)
return Response(DeviceSerializer(self.get_devices(**args.validated_data), many=True).data)
kwargs = {}
if "user" in request.query_params:
kwargs = {"user": request.query_params["user"]}
return Response(DeviceSerializer(self.get_devices(**kwargs), many=True).data)

View File

@ -90,12 +90,6 @@ from authentik.stages.email.utils import TemplateEmailMessage
LOGGER = get_logger()
class ParamUserSerializer(PassiveSerializer):
"""Partial serializer for query parameters to select a user"""
user = PrimaryKeyRelatedField(queryset=User.objects.all().exclude_anonymous(), required=False)
class UserGroupSerializer(ModelSerializer):
"""Simplified Group Serializer for user's groups"""
@ -392,23 +386,8 @@ class UserViewSet(UsedByMixin, ModelViewSet):
queryset = User.objects.none()
ordering = ["username"]
serializer_class = UserSerializer
filterset_class = UsersFilter
search_fields = ["username", "name", "is_active", "email", "uuid", "attributes"]
def get_ql_fields(self):
from djangoql.schema import BoolField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
StrField(User, "username"),
StrField(User, "name"),
StrField(User, "email"),
StrField(User, "path"),
BoolField(User, "is_active", nullable=True),
ChoiceSearchField(User, "type"),
JSONSearchField(User, "attributes", suggest_nested=False),
]
filterset_class = UsersFilter
def get_queryset(self):
base_qs = User.objects.all().exclude_anonymous()

View File

@ -2,7 +2,6 @@
from typing import Any
from django.db import models
from django.db.models import Model
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_basic_type
@ -31,27 +30,7 @@ def is_dict(value: Any):
raise ValidationError("Value must be a dictionary, and not have any duplicate keys.")
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class ModelSerializer(BaseModelSerializer):
# By default, JSON fields we have are used to store dictionaries
serializer_field_mapping = BaseModelSerializer.serializer_field_mapping.copy()
serializer_field_mapping[models.JSONField] = JSONDictField
def create(self, validated_data):
instance = super().create(validated_data)
@ -92,6 +71,21 @@ class ModelSerializer(BaseModelSerializer):
return instance
class JSONDictField(JSONField):
"""JSON Field which only allows dictionaries"""
default_validators = [is_dict]
class JSONExtension(OpenApiSerializerFieldExtension):
"""Generate API Schema for JSON fields as"""
target_class = "authentik.core.api.utils.JSONDictField"
def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.OBJECT)
class PassiveSerializer(Serializer):
"""Base serializer class which doesn't implement create/update methods"""

View File

@ -13,6 +13,7 @@ class Command(TenantCommand):
parser.add_argument("usernames", nargs="*", type=str)
def handle_per_tenant(self, **options):
print(options)
new_type = UserTypes(options["type"])
qs = (
User.objects.exclude_anonymous()

View File

@ -18,7 +18,7 @@ from django.http import HttpRequest
from django.utils.functional import SimpleLazyObject, cached_property
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from django_cte import CTE, with_cte
from django_cte import CTEQuerySet, With
from guardian.conf import settings
from guardian.mixins import GuardianUserMixin
from model_utils.managers import InheritanceManager
@ -136,7 +136,7 @@ class AttributesMixin(models.Model):
return instance, False
class GroupQuerySet(QuerySet):
class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
or are indirectly related."""
@ -165,9 +165,9 @@ class GroupQuerySet(QuerySet):
)
# Build the recursive query, see above
cte = CTE.recursive(make_cte)
cte = With.recursive(make_cte)
# Return the result, as a usable queryset for Group.
return with_cte(cte, select=cte.join(Group, group_uuid=cte.col.group_uuid))
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)
class Group(SerializerModel, AttributesMixin):
@ -1082,12 +1082,6 @@ class AuthenticatedSession(SerializerModel):
user = models.ForeignKey(User, on_delete=models.CASCADE)
@property
def serializer(self) -> type[Serializer]:
from authentik.core.api.authenticated_sessions import AuthenticatedSessionSerializer
return AuthenticatedSessionSerializer
class Meta:
verbose_name = _("Authenticated Session")
verbose_name_plural = _("Authenticated Sessions")

View File

@ -114,7 +114,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -168,7 +167,6 @@ class TestApplicationsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -119,17 +119,17 @@ class TestTrimPasswordHistory(TestCase):
[
UserPasswordHistory(
user=self.user,
old_password="hunter1", # nosec
old_password="hunter1", # nosec B106
created_at=_now - timedelta(days=3),
),
UserPasswordHistory(
user=self.user,
old_password="hunter2", # nosec
old_password="hunter2", # nosec B106
created_at=_now - timedelta(days=2),
),
UserPasswordHistory(
user=self.user,
old_password="hunter3", # nosec
old_password="hunter3", # nosec B106
created_at=_now,
),
]

View File

@ -1,8 +1,10 @@
from hashlib import sha256
from django.contrib.auth.signals import user_logged_out
from django.db.models import Model
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http.request import HttpRequest
from guardian.shortcuts import assign_perm
from authentik.core.models import (
@ -60,6 +62,31 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
instance.save()
@receiver(user_logged_out)
def ssf_user_logged_out_session_revoked(sender, request: HttpRequest, user: User, **_):
"""Session revoked trigger (user logged out)"""
if not request.session or not request.session.session_key or not user:
return
send_ssf_event(
EventTypes.CAEP_SESSION_REVOKED,
{
"initiating_entity": "user",
},
sub_id={
"format": "complex",
"session": {
"format": "opaque",
"id": sha256(request.session.session_key.encode("ascii")).hexdigest(),
},
"user": {
"format": "email",
"email": user.email,
},
},
request=request,
)
@receiver(pre_delete, sender=AuthenticatedSession)
def ssf_user_session_delete_session_revoked(sender, instance: AuthenticatedSession, **_):
"""Session revoked trigger (users' session has been deleted)

View File

@ -1,12 +0,0 @@
"""Enterprise app config"""
from authentik.enterprise.apps import EnterpriseConfig
class AuthentikEnterpriseSearchConfig(EnterpriseConfig):
"""Enterprise app config"""
name = "authentik.enterprise.search"
label = "authentik_search"
verbose_name = "authentik Enterprise.Search"
default = True

View File

@ -1,128 +0,0 @@
"""DjangoQL search"""
from collections import OrderedDict, defaultdict
from collections.abc import Generator
from django.db import connection
from django.db.models import Model, Q
from djangoql.compat import text_type
from djangoql.schema import StrField
class JSONSearchField(StrField):
"""JSON field for DjangoQL"""
model: Model
def __init__(self, model=None, name=None, nullable=None, suggest_nested=True):
# Set this in the constructor to not clobber the type variable
self.type = "relation"
self.suggest_nested = suggest_nested
super().__init__(model, name, nullable)
def get_lookup(self, path, operator, value):
search = "__".join(path)
op, invert = self.get_operator(operator)
q = Q(**{f"{search}{op}": self.get_lookup_value(value)})
return ~q if invert else q
def json_field_keys(self) -> Generator[tuple[str]]:
with connection.cursor() as cursor:
cursor.execute(
f"""
WITH RECURSIVE "{self.name}_keys" AS (
SELECT
ARRAY[jsonb_object_keys("{self.name}")] AS key_path_array,
"{self.name}" -> jsonb_object_keys("{self.name}") AS value
FROM {self.model._meta.db_table}
WHERE "{self.name}" IS NOT NULL
AND jsonb_typeof("{self.name}") = 'object'
UNION ALL
SELECT
ck.key_path_array || jsonb_object_keys(ck.value),
ck.value -> jsonb_object_keys(ck.value) AS value
FROM "{self.name}_keys" ck
WHERE jsonb_typeof(ck.value) = 'object'
),
unique_paths AS (
SELECT DISTINCT key_path_array
FROM "{self.name}_keys"
)
SELECT key_path_array FROM unique_paths;
""" # nosec
)
return (x[0] for x in cursor.fetchall())
def get_nested_options(self) -> OrderedDict:
"""Get keys of all nested objects to show autocomplete"""
if not self.suggest_nested:
return OrderedDict()
base_model_name = f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
def recursive_function(parts: list[str], parent_parts: list[str] | None = None):
if not parent_parts:
parent_parts = []
path = parts.pop(0)
parent_parts.append(path)
relation_key = "_".join(parent_parts)
if len(parts) > 1:
out_dict = {
relation_key: {
parts[0]: {
"type": "relation",
"relation": f"{relation_key}_{parts[0]}",
}
}
}
child_paths = recursive_function(parts.copy(), parent_parts.copy())
child_paths.update(out_dict)
return child_paths
else:
return {relation_key: {parts[0]: {}}}
relation_structure = defaultdict(dict)
for relations in self.json_field_keys():
result = recursive_function([base_model_name] + relations)
for relation_key, value in result.items():
for sub_relation_key, sub_value in value.items():
if not relation_structure[relation_key].get(sub_relation_key, None):
relation_structure[relation_key][sub_relation_key] = sub_value
else:
relation_structure[relation_key][sub_relation_key].update(sub_value)
final_dict = defaultdict(dict)
for key, value in relation_structure.items():
for sub_key, sub_value in value.items():
if not sub_value:
final_dict[key][sub_key] = {
"type": "str",
"nullable": True,
}
else:
final_dict[key][sub_key] = sub_value
return OrderedDict(final_dict)
def relation(self) -> str:
return f"{self.model._meta.app_label}.{self.model._meta.model_name}_{self.name}"
class ChoiceSearchField(StrField):
def __init__(self, model=None, name=None, nullable=None):
super().__init__(model, name, nullable, suggest_options=True)
def get_options(self, search):
result = []
choices = self._field_choices()
if choices:
search = search.lower()
for c in choices:
choice = text_type(c[0])
if search in choice.lower():
result.append(choice)
return result

View File

@ -1,53 +0,0 @@
from rest_framework.response import Response
from authentik.api.pagination import Pagination
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, QLSearch
class AutocompletePagination(Pagination):
def paginate_queryset(self, queryset, request, view=None):
self.view = view
return super().paginate_queryset(queryset, request, view)
def get_autocomplete(self):
schema = QLSearch().get_schema(self.request, self.view)
introspections = {}
if hasattr(self.view, "get_ql_fields"):
from authentik.enterprise.search.schema import AKQLSchemaSerializer
introspections = AKQLSchemaSerializer().serialize(
schema(self.page.paginator.object_list.model)
)
return introspections
def get_paginated_response(self, data):
previous_page_number = 0
if self.page.has_previous():
previous_page_number = self.page.previous_page_number()
next_page_number = 0
if self.page.has_next():
next_page_number = self.page.next_page_number()
return Response(
{
"pagination": {
"next": next_page_number,
"previous": previous_page_number,
"count": self.page.paginator.count,
"current": self.page.number,
"total_pages": self.page.paginator.num_pages,
"start_index": self.page.start_index(),
"end_index": self.page.end_index(),
},
"results": data,
"autocomplete": self.get_autocomplete(),
}
)
def get_paginated_response_schema(self, schema):
final_schema = super().get_paginated_response_schema(schema)
final_schema["properties"]["autocomplete"] = {
"$ref": f"#/components/schemas/{AUTOCOMPLETE_COMPONENT_NAME}"
}
final_schema["required"].append("autocomplete")
return final_schema

View File

@ -1,80 +0,0 @@
"""DjangoQL search"""
from django.apps import apps
from django.db.models import QuerySet
from djangoql.ast import Name
from djangoql.exceptions import DjangoQLError
from djangoql.queryset import apply_search
from djangoql.schema import DjangoQLSchema
from rest_framework.filters import BaseFilterBackend, SearchFilter
from rest_framework.request import Request
from structlog.stdlib import get_logger
from authentik.enterprise.search.fields import JSONSearchField
LOGGER = get_logger()
AUTOCOMPLETE_COMPONENT_NAME = "Autocomplete"
AUTOCOMPLETE_SCHEMA = {
"type": "object",
"additionalProperties": {},
}
class BaseSchema(DjangoQLSchema):
"""Base Schema which deals with JSON Fields"""
def resolve_name(self, name: Name):
model = self.model_label(self.current_model)
root_field = name.parts[0]
field = self.models[model].get(root_field)
# If the query goes into a JSON field, return the root
# field as the JSON field will do the rest
if isinstance(field, JSONSearchField):
# This is a workaround; build_filter will remove the right-most
# entry in the path as that is intended to be the same as the field
# however for JSON that is not the case
if name.parts[-1] != root_field:
name.parts.append(root_field)
return field
return super().resolve_name(name)
class QLSearch(BaseFilterBackend):
"""rest_framework search filter which uses DjangoQL"""
def __init__(self):
super().__init__()
self._fallback = SearchFilter()
@property
def enabled(self):
return apps.get_app_config("authentik_enterprise").enabled()
def get_search_terms(self, request: Request) -> str:
"""Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited."""
params = request.query_params.get("search", "")
params = params.replace("\x00", "") # strip null characters
return params
def get_schema(self, request: Request, view) -> BaseSchema:
ql_fields = []
if hasattr(view, "get_ql_fields"):
ql_fields = view.get_ql_fields()
class InlineSchema(BaseSchema):
def get_fields(self, model):
return ql_fields or []
return InlineSchema
def filter_queryset(self, request: Request, queryset: QuerySet, view) -> QuerySet:
search_query = self.get_search_terms(request)
schema = self.get_schema(request, view)
if len(search_query) == 0 or not self.enabled:
return self._fallback.filter_queryset(request, queryset, view)
try:
return apply_search(queryset, search_query, schema=schema)
except DjangoQLError as exc:
LOGGER.debug("Failed to parse search expression", exc=exc)
return self._fallback.filter_queryset(request, queryset, view)

View File

@ -1,29 +0,0 @@
from djangoql.serializers import DjangoQLSchemaSerializer
from drf_spectacular.generators import SchemaGenerator
from authentik.api.schema import create_component
from authentik.enterprise.search.fields import JSONSearchField
from authentik.enterprise.search.ql import AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA
class AKQLSchemaSerializer(DjangoQLSchemaSerializer):
def serialize(self, schema):
serialization = super().serialize(schema)
for _, fields in schema.models.items():
for _, field in fields.items():
if not isinstance(field, JSONSearchField):
continue
serialization["models"].update(field.get_nested_options())
return serialization
def serialize_field(self, field):
result = super().serialize_field(field)
if isinstance(field, JSONSearchField):
result["relation"] = field.relation()
return result
def postprocess_schema_search_autocomplete(result, generator: SchemaGenerator, **kwargs):
create_component(generator, AUTOCOMPLETE_COMPONENT_NAME, AUTOCOMPLETE_SCHEMA)
return result

View File

@ -1,17 +0,0 @@
SPECTACULAR_SETTINGS = {
"POSTPROCESSING_HOOKS": [
"authentik.api.schema.postprocess_schema_responses",
"authentik.enterprise.search.schema.postprocess_schema_search_autocomplete",
"drf_spectacular.hooks.postprocess_schema_enums",
],
}
REST_FRAMEWORK = {
"DEFAULT_PAGINATION_CLASS": "authentik.enterprise.search.pagination.AutocompletePagination",
"DEFAULT_FILTER_BACKENDS": [
"authentik.enterprise.search.ql.QLSearch",
"authentik.rbac.filters.ObjectFilter",
"django_filters.rest_framework.DjangoFilterBackend",
"rest_framework.filters.OrderingFilter",
],
}

View File

@ -1,78 +0,0 @@
from json import loads
from unittest.mock import PropertyMock, patch
from urllib.parse import urlencode
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
@patch(
"authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
PropertyMock(return_value=True),
)
class QLTest(APITestCase):
def setUp(self):
self.user = create_test_admin_user()
# ensure we have more than 1 user
create_test_admin_user()
def test_search(self):
"""Test simple search query"""
self.client.force_login(self.user)
query = f'username = "{self.user.username}"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_no_search(self):
"""Ensure works with no search query"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertNotEqual(content["pagination"]["count"], 1)
def test_search_no_ql(self):
"""Test simple search query (no QL)"""
self.client.force_login(self.user)
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": self.user.username})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)
def test_search_json(self):
"""Test search query with a JSON attribute"""
self.user.attributes = {"foo": {"bar": "baz"}}
self.user.save()
self.client.force_login(self.user)
query = 'attributes.foo.bar = "baz"'
res = self.client.get(
reverse(
"authentik_api:user-list",
)
+ f"?{urlencode({"search": query})}"
)
self.assertEqual(res.status_code, 200)
content = loads(res.content)
self.assertEqual(content["pagination"]["count"], 1)
self.assertEqual(content["results"][0]["username"], self.user.username)

View File

@ -18,7 +18,6 @@ TENANT_APPS = [
"authentik.enterprise.providers.google_workspace",
"authentik.enterprise.providers.microsoft_entra",
"authentik.enterprise.providers.ssf",
"authentik.enterprise.search",
"authentik.enterprise.stages.authenticator_endpoint_gdtc",
"authentik.enterprise.stages.mtls",
"authentik.enterprise.stages.source",

View File

@ -97,7 +97,6 @@ class SourceStageFinal(StageView):
token: FlowToken = self.request.session.get(SESSION_KEY_OVERRIDE_FLOW_TOKEN)
self.logger.info("Replacing source flow with overridden flow", flow=token.flow.slug)
plan = token.plan
plan.context.update(self.executor.plan.context)
plan.context[PLAN_CONTEXT_IS_RESTORED] = token
response = plan.to_redirect(self.request, token.flow)
token.delete()

View File

@ -90,17 +90,14 @@ class TestSourceStage(FlowTestCase):
plan: FlowPlan = session[SESSION_KEY_PLAN]
plan.insert_stage(in_memory_stage(SourceStageFinal), index=0)
plan.context[PLAN_CONTEXT_IS_RESTORED] = flow_token
plan.context["foo"] = "bar"
session[SESSION_KEY_PLAN] = plan
session.save()
# Pretend we've just returned from the source
with self.assertFlowFinishes() as ff:
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)
self.assertEqual(ff().context["foo"], "bar")
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), follow=True
)
self.assertEqual(response.status_code, 200)
self.assertStageRedirects(
response, reverse("authentik_core:if-flow", kwargs={"flow_slug": flow.slug})
)

View File

@ -132,22 +132,6 @@ class EventViewSet(ModelViewSet):
]
filterset_class = EventsFilter
def get_ql_fields(self):
from djangoql.schema import DateTimeField, StrField
from authentik.enterprise.search.fields import ChoiceSearchField, JSONSearchField
return [
ChoiceSearchField(Event, "action"),
StrField(Event, "event_uuid"),
StrField(Event, "app", suggest_options=True),
StrField(Event, "client_ip"),
JSONSearchField(Event, "user", suggest_nested=False),
JSONSearchField(Event, "brand", suggest_nested=False),
JSONSearchField(Event, "context", suggest_nested=False),
DateTimeField(Event, "created", suggest_options=True),
]
@extend_schema(
methods=["GET"],
responses={200: EventTopPerUserSerializer(many=True)},

View File

@ -11,7 +11,7 @@ from authentik.events.models import NotificationRule
class NotificationRuleSerializer(ModelSerializer):
"""NotificationRule Serializer"""
destination_group_obj = GroupSerializer(read_only=True, source="destination_group")
group_obj = GroupSerializer(read_only=True, source="group")
class Meta:
model = NotificationRule
@ -20,9 +20,8 @@ class NotificationRuleSerializer(ModelSerializer):
"name",
"transports",
"severity",
"destination_group",
"destination_group_obj",
"destination_event_user",
"group",
"group_obj",
]
@ -31,6 +30,6 @@ class NotificationRuleViewSet(UsedByMixin, ModelViewSet):
queryset = NotificationRule.objects.all()
serializer_class = NotificationRuleSerializer
filterset_fields = ["name", "severity", "destination_group__name"]
filterset_fields = ["name", "severity", "group__name"]
ordering = ["name"]
search_fields = ["name", "destination_group__name"]
search_fields = ["name", "group__name"]

View File

@ -15,13 +15,13 @@ class MMDBContextProcessor(EventContextProcessor):
self.reader: Reader | None = None
self._last_mtime: float = 0.0
self.logger = get_logger()
self.load()
self.open()
def path(self) -> str | None:
"""Get the path to the MMDB file to load"""
raise NotImplementedError
def load(self):
def open(self):
"""Get GeoIP Reader, if configured, otherwise none"""
path = self.path()
if path == "" or not path:
@ -44,7 +44,7 @@ class MMDBContextProcessor(EventContextProcessor):
diff = self._last_mtime < mtime
if diff > 0:
self.logger.info("Found new MMDB Database, reopening", diff=diff, path=path)
self.load()
self.open()
except OSError as exc:
self.logger.warning("Failed to check MMDB age", exc=exc)

View File

@ -19,7 +19,7 @@ from authentik.blueprints.v1.importer import excluded_models
from authentik.core.models import Group, User
from authentik.events.models import Event, EventAction, Notification
from authentik.events.utils import model_to_dict
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
from authentik.stages.authenticator_static.models import StaticToken
@ -173,7 +173,7 @@ class AuditMiddleware:
message=exception_to_string(exception),
)
thread.run()
elif not should_ignore_exception(exception):
elif before_send({}, {"exc_info": (None, exception, None)}) is not None:
thread = EventNewThread(
EventAction.SYSTEM_EXCEPTION,
request,

View File

@ -1,26 +0,0 @@
# Generated by Django 5.1.11 on 2025-06-16 23:21
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_events", "0009_remove_notificationtransport_webhook_mapping_and_more"),
]
operations = [
migrations.RenameField(
model_name="notificationrule",
old_name="group",
new_name="destination_group",
),
migrations.AddField(
model_name="notificationrule",
name="destination_event_user",
field=models.BooleanField(
default=False,
help_text="When enabled, notification will be sent to user the user that triggered the event.When destination_group is configured, notification is sent to both.",
),
),
]

View File

@ -1,12 +1,10 @@
"""authentik events models"""
from collections.abc import Generator
from datetime import timedelta
from difflib import get_close_matches
from functools import lru_cache
from inspect import currentframe
from smtplib import SMTPException
from typing import Any
from uuid import uuid4
from django.apps import apps
@ -193,32 +191,17 @@ class Event(SerializerModel, ExpiringModel):
brand: Brand = request.brand
self.brand = sanitize_dict(model_to_dict(brand))
if hasattr(request, "user"):
self.user = get_user(request.user)
original_user = None
if hasattr(request, "session"):
original_user = request.session.get(SESSION_KEY_IMPERSONATE_ORIGINAL_USER, None)
self.user = get_user(request.user, original_user)
if user:
self.user = get_user(user)
# Check if we're currently impersonating, and add that user
if hasattr(request, "session"):
from authentik.flows.views.executor import SESSION_KEY_PLAN
# Check if we're currently impersonating, and add that user
if SESSION_KEY_IMPERSONATE_ORIGINAL_USER in request.session:
self.user = get_user(request.session[SESSION_KEY_IMPERSONATE_ORIGINAL_USER])
self.user["on_behalf_of"] = get_user(request.session[SESSION_KEY_IMPERSONATE_USER])
# Special case for events that happen during a flow, the user might not be authenticated
# yet but is a pending user instead
if SESSION_KEY_PLAN in request.session:
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
plan: FlowPlan = request.session[SESSION_KEY_PLAN]
pending_user = plan.context.get(PLAN_CONTEXT_PENDING_USER, None)
# Only save `authenticated_as` if there's a different pending user in the flow
# than the user that is authenticated
if pending_user and (
(pending_user.pk and pending_user.pk != self.user.get("pk"))
or (not pending_user.pk)
):
orig_user = self.user.copy()
self.user = {"authenticated_as": orig_user, **get_user(pending_user)}
# User 255.255.255.255 as fallback if IP cannot be determined
self.client_ip = ClientIPMiddleware.get_client_ip(request)
# Enrich event data
@ -564,7 +547,7 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
default=NotificationSeverity.NOTICE,
help_text=_("Controls which severity level the created notifications will have."),
)
destination_group = models.ForeignKey(
group = models.ForeignKey(
Group,
help_text=_(
"Define which group of users this notification should be sent and shown to. "
@ -574,19 +557,6 @@ class NotificationRule(SerializerModel, PolicyBindingModel):
blank=True,
on_delete=models.SET_NULL,
)
destination_event_user = models.BooleanField(
default=False,
help_text=_(
"When enabled, notification will be sent to user the user that triggered the event."
"When destination_group is configured, notification is sent to both."
),
)
def destination_users(self, event: Event) -> Generator[User, Any]:
if self.destination_event_user and event.user.get("pk"):
yield User(pk=event.user.get("pk"))
if self.destination_group:
yield from self.destination_group.users.all()
@property
def serializer(self) -> type[Serializer]:

View File

@ -68,10 +68,14 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
if not result.passing:
return
if not trigger.group:
LOGGER.debug("e(trigger): trigger has no group", trigger=trigger)
return
LOGGER.debug("e(trigger): event trigger matched", trigger=trigger)
# Create the notification objects
for transport in trigger.transports.all():
for user in trigger.destination_users(event):
for user in trigger.group.users.all():
LOGGER.debug("created notification")
notification_transport.apply_async(
args=[

View File

@ -2,9 +2,7 @@
from django.test import TestCase
from authentik.events.context_processors.base import get_context_processors
from authentik.events.context_processors.geoip import GeoIPContextProcessor
from authentik.events.models import Event, EventAction
class TestGeoIP(TestCase):
@ -15,7 +13,8 @@ class TestGeoIP(TestCase):
def test_simple(self):
"""Test simple city wrapper"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
# IPs from
# https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
self.assertEqual(
self.reader.city_dict("2.125.160.216"),
{
@ -26,12 +25,3 @@ class TestGeoIP(TestCase):
"long": -1.25,
},
)
def test_special_chars(self):
"""Test city name with special characters"""
# IPs from https://github.com/maxmind/MaxMind-DB/blob/main/source-data/GeoLite2-City-Test.json
event = Event.new(EventAction.LOGIN)
event.client_ip = "89.160.20.112"
for processor in get_context_processors():
processor.enrich_event(event)
event.save()

View File

@ -8,11 +8,9 @@ from django.views.debug import SafeExceptionReporterFilter
from guardian.shortcuts import get_anonymous_user
from authentik.brands.models import Brand
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.core.models import Group
from authentik.events.models import Event
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan
from authentik.flows.views.executor import QS_QUERY, SESSION_KEY_PLAN
from authentik.flows.views.executor import QS_QUERY
from authentik.lib.generators import generate_id
from authentik.policies.dummy.models import DummyPolicy
@ -118,92 +116,3 @@ class TestEvents(TestCase):
"pk": brand.pk.hex,
},
)
def test_from_http_flow_pending_user(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = user
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_anon(self):
"""Test request from flow request with a pending user"""
user = create_test_user()
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)
def test_from_http_flow_pending_user_fake(self):
"""Test request from flow request with a pending user"""
user = User(
username=generate_id(),
email=generate_id(),
)
anon = get_anonymous_user()
session = self.client.session
plan = FlowPlan(generate_id())
plan.context[PLAN_CONTEXT_PENDING_USER] = user
session[SESSION_KEY_PLAN] = plan
session.save()
request = self.factory.get("/")
request.session = session
request.user = anon
event = Event.new("unittest").from_http(request)
self.assertEqual(
event.user,
{
"authenticated_as": {
"pk": anon.pk,
"is_anonymous": True,
"username": "AnonymousUser",
"email": "",
},
"email": user.email,
"pk": user.pk,
"username": user.username,
},
)

View File

@ -6,7 +6,6 @@ from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_user
from authentik.events.models import (
Event,
EventAction,
@ -35,7 +34,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_empty(self):
"""Test trigger without any policies attached"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
@ -47,7 +46,7 @@ class TestEventsNotifications(APITestCase):
def test_trigger_single(self):
"""Test simple transport triggering"""
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -60,25 +59,6 @@ class TestEventsNotifications(APITestCase):
Event.new(EventAction.CUSTOM_PREFIX).save()
self.assertEqual(execute_mock.call_count, 1)
def test_trigger_event_user(self):
"""Test trigger with event user"""
user = create_test_user()
transport = NotificationTransport.objects.create(name=generate_id())
trigger = NotificationRule.objects.create(name=generate_id(), destination_event_user=True)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX
)
PolicyBinding.objects.create(target=trigger, policy=matcher, order=0)
execute_mock = MagicMock()
with patch("authentik.events.models.NotificationTransport.send", execute_mock):
Event.new(EventAction.CUSTOM_PREFIX).set_user(user).save()
self.assertEqual(execute_mock.call_count, 1)
notification: Notification = execute_mock.call_args[0][0]
self.assertEqual(notification.user, user)
def test_trigger_no_group(self):
"""Test trigger without group"""
trigger = NotificationRule.objects.create(name=generate_id())
@ -96,7 +76,7 @@ class TestEventsNotifications(APITestCase):
"""Test Policy error which would cause recursion"""
transport = NotificationTransport.objects.create(name=generate_id())
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -119,7 +99,7 @@ class TestEventsNotifications(APITestCase):
transport = NotificationTransport.objects.create(name=generate_id(), send_once=True)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
trigger.save()
matcher = EventMatcherPolicy.objects.create(
@ -143,7 +123,7 @@ class TestEventsNotifications(APITestCase):
name=generate_id(), webhook_mapping_body=mapping, mode=TransportMode.LOCAL
)
NotificationRule.objects.filter(name__startswith="default").delete()
trigger = NotificationRule.objects.create(name=generate_id(), destination_group=self.group)
trigger = NotificationRule.objects.create(name=generate_id(), group=self.group)
trigger.transports.add(transport)
matcher = EventMatcherPolicy.objects.create(
name="matcher", action=EventAction.CUSTOM_PREFIX

View File

@ -74,8 +74,8 @@ def model_to_dict(model: Model) -> dict[str, Any]:
}
def get_user(user: User | AnonymousUser) -> dict[str, Any]:
"""Convert user object to dictionary"""
def get_user(user: User | AnonymousUser, original_user: User | None = None) -> dict[str, Any]:
"""Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser):
try:
user = get_anonymous_user()
@ -88,6 +88,10 @@ def get_user(user: User | AnonymousUser) -> dict[str, Any]:
}
if user.username == settings.ANONYMOUS_USER_NAME:
user_data["is_anonymous"] = True
if original_user:
original_data = get_user(original_user)
original_data["on_behalf_of"] = user_data
return original_data
return user_data

View File

@ -4,10 +4,8 @@ from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import urlencode
from django.http import HttpRequest, HttpResponse
from django.test import override_settings
from django.test.client import RequestFactory
from django.urls import reverse
from rest_framework.exceptions import ParseError
from authentik.core.models import Group, User
from authentik.core.tests.utils import create_test_flow, create_test_user
@ -650,25 +648,3 @@ class TestFlowExecutor(FlowTestCase):
self.assertStageResponse(response, flow, component="ak-stage-identification")
response = self.client.post(exec_url, {"uid_field": user_other.username}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-access-denied")
@patch(
"authentik.flows.views.executor.to_stage_response",
TO_STAGE_RESPONSE_MOCK,
)
def test_invalid_json(self):
"""Test invalid JSON body"""
flow = create_test_flow()
FlowStageBinding.objects.create(
target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
)
url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
with override_settings(TEST=False, DEBUG=False):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)
with self.assertRaises(ParseError):
self.client.logout()
response = self.client.post(url, data="{", content_type="application/json")
self.assertEqual(response.status_code, 200)

View File

@ -55,7 +55,7 @@ from authentik.flows.planner import (
FlowPlanner,
)
from authentik.flows.stage import AccessDeniedStage, StageView
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException
from authentik.lib.utils.errors import exception_to_string
from authentik.lib.utils.reflection import all_subclasses, class_to_path
from authentik.lib.utils.urls import is_url_absolute, redirect_with_qs
@ -69,6 +69,7 @@ SESSION_KEY_APPLICATION_PRE = "authentik/flows/application_pre"
SESSION_KEY_GET = "authentik/flows/get"
SESSION_KEY_POST = "authentik/flows/post"
SESSION_KEY_HISTORY = "authentik/flows/history"
SESSION_KEY_AUTH_STARTED = "authentik/flows/auth_started"
QS_KEY_TOKEN = "flow_token" # nosec
QS_QUERY = "query"
@ -234,13 +235,12 @@ class FlowExecutorView(APIView):
"""Handle exception in stage execution"""
if settings.DEBUG or settings.TEST:
raise exc
capture_exception(exc)
self._logger.warning(exc)
if not should_ignore_exception(exc):
capture_exception(exc)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
Event.new(
action=EventAction.SYSTEM_EXCEPTION,
message=exception_to_string(exc),
).from_http(self.request)
challenge = FlowErrorChallenge(self.request, exc)
challenge.is_valid(raise_exception=True)
return to_stage_response(self.request, HttpChallengeResponse(challenge))
@ -455,6 +455,7 @@ class FlowExecutorView(APIView):
SESSION_KEY_APPLICATION_PRE,
SESSION_KEY_PLAN,
SESSION_KEY_GET,
SESSION_KEY_AUTH_STARTED,
# We might need the initial POST payloads for later requests
# SESSION_KEY_POST,
# We don't delete the history on purpose, as a user might

View File

@ -6,7 +6,8 @@ from django.shortcuts import get_object_or_404
from ua_parser.user_agent_parser import Parse
from authentik.core.views.interface import InterfaceView
from authentik.flows.models import Flow
from authentik.flows.models import Flow, FlowDesignation
from authentik.flows.views.executor import SESSION_KEY_AUTH_STARTED
class FlowInterfaceView(InterfaceView):
@ -14,6 +15,12 @@ class FlowInterfaceView(InterfaceView):
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
if (
not self.request.user.is_authenticated
and flow.designation == FlowDesignation.AUTHENTICATION
):
self.request.session[SESSION_KEY_AUTH_STARTED] = True
self.request.session.save()
kwargs["flow"] = flow
kwargs["flow_background_url"] = flow.background_url(self.request)
kwargs["inspector"] = "inspector" in self.request.GET

View File

@ -14,7 +14,6 @@ from django_redis.exceptions import ConnectionInterrupted
from docker.errors import DockerException
from h11 import LocalProtocolError
from ldap3.core.exceptions import LDAPException
from psycopg.errors import Error
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError, ResponseError
from rest_framework.exceptions import APIException
@ -45,49 +44,6 @@ class SentryIgnoredException(Exception):
"""Base Class for all errors that are suppressed, and not sent to sentry."""
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
class SentryTransport(HttpTransport):
"""Custom sentry transport with custom user-agent"""
@ -145,17 +101,56 @@ def traces_sampler(sampling_context: dict) -> float:
return float(CONFIG.get("error_reporting.sample_rate", 0.1))
def should_ignore_exception(exc: Exception) -> bool:
"""Check if an exception should be dropped"""
return isinstance(exc, ignored_classes)
def before_send(event: dict, hint: dict) -> dict | None:
"""Check if error is database error, and ignore if so"""
from psycopg.errors import Error
ignored_classes = (
# Inbuilt types
KeyboardInterrupt,
ConnectionResetError,
OSError,
PermissionError,
# Django Errors
Error,
ImproperlyConfigured,
DatabaseError,
OperationalError,
InternalError,
ProgrammingError,
SuspiciousOperation,
ValidationError,
# Redis errors
RedisConnectionError,
ConnectionInterrupted,
RedisError,
ResponseError,
# websocket errors
ChannelFull,
WebSocketException,
LocalProtocolError,
# rest_framework error
APIException,
# celery errors
WorkerLostError,
CeleryError,
SoftTimeLimitExceeded,
# custom baseclass
SentryIgnoredException,
# ldap errors
LDAPException,
# Docker errors
DockerException,
# End-user errors
Http404,
# AsyncIO
CancelledError,
)
exc_value = None
if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"]
if should_ignore_exception(exc_value):
if isinstance(exc_value, ignored_classes):
LOGGER.debug("dropping exception", exc=exc_value)
return None
if "logger" in event:

View File

@ -2,7 +2,7 @@
from django.test import TestCase
from authentik.lib.sentry import SentryIgnoredException, should_ignore_exception
from authentik.lib.sentry import SentryIgnoredException, before_send
class TestSentry(TestCase):
@ -10,8 +10,8 @@ class TestSentry(TestCase):
def test_error_not_sent(self):
"""Test SentryIgnoredError not sent"""
self.assertTrue(should_ignore_exception(SentryIgnoredException()))
self.assertIsNone(before_send({}, {"exc_info": (0, SentryIgnoredException(), 0)}))
def test_error_sent(self):
"""Test error sent"""
self.assertFalse(should_ignore_exception(ValueError()))
self.assertEqual({}, before_send({}, {"exc_info": (0, ValueError(), 0)}))

View File

@ -1,13 +1,15 @@
"""authentik outpost signals"""
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models import Model
from django.db.models.signals import m2m_changed, post_save, pre_delete, pre_save
from django.dispatch import receiver
from django.http import HttpRequest
from structlog.stdlib import get_logger
from authentik.brands.models import Brand
from authentik.core.models import AuthenticatedSession, Provider
from authentik.core.models import AuthenticatedSession, Provider, User
from authentik.crypto.models import CertificateKeyPair
from authentik.lib.utils.reflection import class_to_path
from authentik.outposts.models import Outpost, OutpostServiceConnection
@ -80,6 +82,14 @@ def pre_delete_cleanup(sender, instance: Outpost, **_):
outpost_controller.delay(instance.pk.hex, action="down", from_cache=True)
@receiver(user_logged_out)
def logout_revoke_direct(sender: type[User], request: HttpRequest, **_):
"""Catch logout by direct logout and forward to providers"""
if not request.session or not request.session.session_key:
return
outpost_session_end.delay(request.session.session_key)
@receiver(pre_delete, sender=AuthenticatedSession)
def logout_revoke(sender: type[AuthenticatedSession], instance: AuthenticatedSession, **_):
"""Catch logout by expiring sessions being deleted"""

View File

@ -39,3 +39,4 @@ class AuthentikPoliciesConfig(ManagedAppConfig):
label = "authentik_policies"
verbose_name = "authentik Policies"
default = True
mountpoint = "policy/"

View File

@ -0,0 +1,89 @@
{% extends 'login/base_full.html' %}
{% load static %}
{% load i18n %}
{% block head %}
{{ block.super }}
<script>
let redirecting = false;
const checkAuth = async () => {
if (redirecting) return true;
const url = "{{ check_auth_url }}";
console.debug("authentik/policies/buffer: Checking authentication...");
try {
const result = await fetch(url, {
method: "HEAD",
});
if (result.status >= 400) {
return false
}
console.debug("authentik/policies/buffer: Continuing");
redirecting = true;
if ("{{ auth_req_method }}" === "post") {
document.querySelector("form").submit();
} else {
window.location.assign("{{ continue_url|escapejs }}");
}
} catch {
return false;
}
};
let timeout = 100;
let offset = 20;
let attempt = 0;
const main = async () => {
attempt += 1;
await checkAuth();
console.debug(`authentik/policies/buffer: Waiting ${timeout}ms...`);
setTimeout(main, timeout);
timeout += (offset * attempt);
if (timeout >= 2000) {
timeout = 2000;
}
}
document.addEventListener("visibilitychange", async () => {
if (document.hidden) return;
console.debug("authentik/policies/buffer: Checking authentication on tab activate...");
await checkAuth();
});
main();
</script>
{% endblock %}
{% block title %}
{% trans 'Waiting for authentication...' %} - {{ brand.branding_title }}
{% endblock %}
{% block card_title %}
{% trans 'Waiting for authentication...' %}
{% endblock %}
{% block card %}
<form class="pf-c-form" method="{{ auth_req_method }}" action="{{ continue_url }}">
{% if auth_req_method == "post" %}
{% for key, value in auth_req_body.items %}
<input type="hidden" name="{{ key }}" value="{{ value }}" />
{% endfor %}
{% endif %}
<div class="pf-c-empty-state">
<div class="pf-c-empty-state__content">
<div class="pf-c-empty-state__icon">
<span class="pf-c-spinner pf-m-xl" role="progressbar">
<span class="pf-c-spinner__clipper"></span>
<span class="pf-c-spinner__lead-ball"></span>
<span class="pf-c-spinner__tail-ball"></span>
</span>
</div>
<h1 class="pf-c-title pf-m-lg">
{% trans "You're already authenticating in another tab. This page will refresh once authentication is completed." %}
</h1>
</div>
</div>
<div class="pf-c-form__group pf-m-action">
<a href="{{ auth_req_url }}" class="pf-c-button pf-m-primary pf-m-block">
{% trans "Authenticate in this tab" %}
</a>
</div>
</form>
{% endblock %}

View File

@ -0,0 +1,121 @@
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpResponse
from django.test import RequestFactory, TestCase
from django.urls import reverse
from authentik.core.models import Application, Provider
from authentik.core.tests.utils import create_test_flow, create_test_user
from authentik.flows.models import FlowDesignation
from authentik.flows.planner import FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.generators import generate_id
from authentik.lib.tests.utils import dummy_get_response
from authentik.policies.views import (
QS_BUFFER_ID,
SESSION_KEY_BUFFER,
BufferedPolicyAccessView,
BufferView,
PolicyAccessView,
)
class TestPolicyViews(TestCase):
"""Test PolicyAccessView"""
def setUp(self):
super().setUp()
self.factory = RequestFactory()
self.user = create_test_user()
def test_pav(self):
"""Test simple policy access view"""
provider = Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
class TestView(PolicyAccessView):
def resolve_provider_application(self):
self.provider = provider
self.application = app
def get(self, *args, **kwargs):
return HttpResponse("foo")
req = self.factory.get("/")
req.user = self.user
res = TestView.as_view()(req)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.content, b"foo")
def test_pav_buffer(self):
"""Test simple policy access view"""
provider = Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
flow = create_test_flow(FlowDesignation.AUTHENTICATION)
class TestView(BufferedPolicyAccessView):
def resolve_provider_application(self):
self.provider = provider
self.application = app
def get(self, *args, **kwargs):
return HttpResponse("foo")
req = self.factory.get("/")
req.user = AnonymousUser()
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(req)
req.session[SESSION_KEY_PLAN] = FlowPlan(flow.pk)
req.session.save()
res = TestView.as_view()(req)
self.assertEqual(res.status_code, 302)
self.assertTrue(res.url.startswith(reverse("authentik_policies:buffer")))
def test_pav_buffer_skip(self):
"""Test simple policy access view (skip buffer)"""
provider = Provider.objects.create(
name=generate_id(),
)
app = Application.objects.create(name=generate_id(), slug=generate_id(), provider=provider)
flow = create_test_flow(FlowDesignation.AUTHENTICATION)
class TestView(BufferedPolicyAccessView):
def resolve_provider_application(self):
self.provider = provider
self.application = app
def get(self, *args, **kwargs):
return HttpResponse("foo")
req = self.factory.get("/?skip_buffer=true")
req.user = AnonymousUser()
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(req)
req.session[SESSION_KEY_PLAN] = FlowPlan(flow.pk)
req.session.save()
res = TestView.as_view()(req)
self.assertEqual(res.status_code, 302)
self.assertTrue(res.url.startswith(reverse("authentik_flows:default-authentication")))
def test_buffer(self):
"""Test buffer view"""
uid = generate_id()
req = self.factory.get(f"/?{QS_BUFFER_ID}={uid}")
req.user = AnonymousUser()
middleware = SessionMiddleware(dummy_get_response)
middleware.process_request(req)
ts = generate_id()
req.session[SESSION_KEY_BUFFER % uid] = {
"method": "get",
"body": {},
"url": f"/{ts}",
}
req.session.save()
res = BufferView.as_view()(req)
self.assertEqual(res.status_code, 200)
self.assertIn(ts, res.render().content.decode())

View File

@ -1,7 +1,14 @@
"""API URLs"""
from django.urls import path
from authentik.policies.api.bindings import PolicyBindingViewSet
from authentik.policies.api.policies import PolicyViewSet
from authentik.policies.views import BufferView
urlpatterns = [
path("buffer", BufferView.as_view(), name="buffer"),
]
api_urlpatterns = [
("policies/all", PolicyViewSet),

View File

@ -1,23 +1,37 @@
"""authentik access helper classes"""
from typing import Any
from uuid import uuid4
from django.contrib import messages
from django.contrib.auth.mixins import AccessMixin
from django.contrib.auth.views import redirect_to_login
from django.http import HttpRequest, HttpResponse
from django.http import HttpRequest, HttpResponse, QueryDict
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.http import urlencode
from django.utils.translation import gettext as _
from django.views.generic.base import View
from django.views.generic.base import TemplateView, View
from structlog.stdlib import get_logger
from authentik.core.models import Application, Provider, User
from authentik.flows.views.executor import SESSION_KEY_APPLICATION_PRE, SESSION_KEY_POST
from authentik.flows.models import Flow, FlowDesignation
from authentik.flows.planner import FlowPlan
from authentik.flows.views.executor import (
SESSION_KEY_APPLICATION_PRE,
SESSION_KEY_AUTH_STARTED,
SESSION_KEY_PLAN,
SESSION_KEY_POST,
)
from authentik.lib.sentry import SentryIgnoredException
from authentik.policies.denied import AccessDeniedResponse
from authentik.policies.engine import PolicyEngine
from authentik.policies.types import PolicyRequest, PolicyResult
LOGGER = get_logger()
QS_BUFFER_ID = "af_bf_id"
QS_SKIP_BUFFER = "skip_buffer"
SESSION_KEY_BUFFER = "authentik/policies/pav_buffer/%s"
class RequestValidationError(SentryIgnoredException):
@ -125,3 +139,65 @@ class PolicyAccessView(AccessMixin, View):
for message in result.messages:
messages.error(self.request, _(message))
return result
def url_with_qs(url: str, **kwargs):
"""Update/set querystring of `url` with the parameters in `kwargs`. Original query string
parameters are retained"""
if "?" not in url:
return url + f"?{urlencode(kwargs)}"
url, _, qs = url.partition("?")
qs = QueryDict(qs, mutable=True)
qs.update(kwargs)
return url + f"?{urlencode(qs.items())}"
class BufferView(TemplateView):
"""Buffer view"""
template_name = "policies/buffer.html"
def get_context_data(self, **kwargs):
buf_id = self.request.GET.get(QS_BUFFER_ID)
buffer: dict = self.request.session.get(SESSION_KEY_BUFFER % buf_id)
kwargs["auth_req_method"] = buffer["method"]
kwargs["auth_req_body"] = buffer["body"]
kwargs["auth_req_url"] = url_with_qs(buffer["url"], **{QS_SKIP_BUFFER: True})
kwargs["check_auth_url"] = reverse("authentik_api:user-me")
kwargs["continue_url"] = url_with_qs(buffer["url"], **{QS_BUFFER_ID: buf_id})
return super().get_context_data(**kwargs)
class BufferedPolicyAccessView(PolicyAccessView):
"""PolicyAccessView which buffers access requests in case the user is not logged in"""
def handle_no_permission(self):
plan: FlowPlan | None = self.request.session.get(SESSION_KEY_PLAN)
authenticating = self.request.session.get(SESSION_KEY_AUTH_STARTED)
if plan:
flow = Flow.objects.filter(pk=plan.flow_pk).first()
if not flow or flow.designation != FlowDesignation.AUTHENTICATION:
LOGGER.debug("Not buffering request, no flow or flow not for authentication")
return super().handle_no_permission()
if not plan and authenticating is None:
LOGGER.debug("Not buffering request, no flow plan active")
return super().handle_no_permission()
if self.request.GET.get(QS_SKIP_BUFFER):
LOGGER.debug("Not buffering request, explicit skip")
return super().handle_no_permission()
buffer_id = str(uuid4())
LOGGER.debug("Buffering access request", bf_id=buffer_id)
self.request.session[SESSION_KEY_BUFFER % buffer_id] = {
"body": self.request.POST,
"url": self.request.build_absolute_uri(self.request.get_full_path()),
"method": self.request.method.lower(),
}
return redirect(
url_with_qs(reverse("authentik_policies:buffer"), **{QS_BUFFER_ID: buffer_id})
)
def dispatch(self, request, *args, **kwargs):
response = super().dispatch(request, *args, **kwargs)
if QS_BUFFER_ID in self.request.GET:
self.request.session.pop(SESSION_KEY_BUFFER % self.request.GET[QS_BUFFER_ID], None)
return response

View File

@ -1,10 +1,23 @@
from django.contrib.auth.signals import user_logged_out
from django.db.models.signals import post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.oauth2.models import AccessToken, DeviceToken, RefreshToken
@receiver(user_logged_out)
def user_logged_out_oauth_tokens_removal(sender, request: HttpRequest, user: User, **_):
"""Revoke tokens upon user logout"""
if not request.session or not request.session.session_key:
return
AccessToken.objects.filter(
user=user,
session__session__session_key=request.session.session_key,
).delete()
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted_oauth_tokens_removal(sender, instance: AuthenticatedSession, **_):
"""Revoke tokens upon user logout"""

View File

@ -387,7 +387,8 @@ class TestAuthorize(OAuthTestCase):
self.assertEqual(
response.url,
(
f"http://localhost#id_token={provider.encode(token.id_token.to_dict())}"
f"http://localhost#access_token={token.token}"
f"&id_token={provider.encode(token.id_token.to_dict())}"
f"&token_type={TOKEN_TYPE}"
f"&expires_in={int(expires)}&state={state}"
),
@ -562,6 +563,7 @@ class TestAuthorize(OAuthTestCase):
"url": "http://localhost",
"title": f"Redirecting to {app.name}...",
"attrs": {
"access_token": token.token,
"id_token": provider.encode(token.id_token.to_dict()),
"token_type": TOKEN_TYPE,
"expires_in": "3600",

View File

@ -30,7 +30,7 @@ from authentik.flows.stage import StageView
from authentik.lib.utils.time import timedelta_from_string
from authentik.lib.views import bad_request_message
from authentik.policies.types import PolicyRequest
from authentik.policies.views import PolicyAccessView, RequestValidationError
from authentik.policies.views import BufferedPolicyAccessView, RequestValidationError
from authentik.providers.oauth2.constants import (
PKCE_METHOD_PLAIN,
PKCE_METHOD_S256,
@ -150,12 +150,12 @@ class OAuthAuthorizationParams:
self.check_redirect_uri()
self.check_grant()
self.check_scope(github_compat)
self.check_nonce()
self.check_code_challenge()
if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
self.check_nonce()
self.check_code_challenge()
def check_grant(self):
"""Check grant"""
@ -326,7 +326,7 @@ class OAuthAuthorizationParams:
return code
class AuthorizationFlowInitView(PolicyAccessView):
class AuthorizationFlowInitView(BufferedPolicyAccessView):
"""OAuth2 Flow initializer, checks access to application and starts flow"""
params: OAuthAuthorizationParams
@ -630,6 +630,7 @@ class OAuthFulfillmentStage(StageView):
if self.params.response_type in [
ResponseTypes.ID_TOKEN_TOKEN,
ResponseTypes.CODE_ID_TOKEN_TOKEN,
ResponseTypes.ID_TOKEN,
ResponseTypes.CODE_TOKEN,
]:
query_fragment["access_token"] = token.token

View File

@ -66,10 +66,7 @@ class RACClientConsumer(AsyncWebsocketConsumer):
def init_outpost_connection(self):
"""Initialize guac connection settings"""
self.token = (
ConnectionToken.filter_not_expired(
token=self.scope["url_route"]["kwargs"]["token"],
session__session__session_key=self.scope["session"].session_key,
)
ConnectionToken.filter_not_expired(token=self.scope["url_route"]["kwargs"]["token"])
.select_related("endpoint", "provider", "session", "session__user")
.first()
)

View File

@ -2,11 +2,13 @@
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.contrib.auth.signals import user_logged_out
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save, pre_delete
from django.dispatch import receiver
from django.http import HttpRequest
from authentik.core.models import AuthenticatedSession
from authentik.core.models import AuthenticatedSession, User
from authentik.providers.rac.api.endpoints import user_endpoint_cache_key
from authentik.providers.rac.consumer_client import (
RAC_CLIENT_GROUP_SESSION,
@ -15,6 +17,21 @@ from authentik.providers.rac.consumer_client import (
from authentik.providers.rac.models import ConnectionToken, Endpoint
@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
"""Disconnect any open RAC connections"""
if not request.session or not request.session.session_key:
return
layer = get_channel_layer()
async_to_sync(layer.group_send)(
RAC_CLIENT_GROUP_SESSION
% {
"session": request.session.session_key,
},
{"type": "event.disconnect", "reason": "session_logout"},
)
@receiver(pre_delete, sender=AuthenticatedSession)
def user_session_deleted(sender, instance: AuthenticatedSession, **_):
layer = get_channel_layer()

View File

@ -49,7 +49,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -102,7 +101,6 @@ class TestEndpointsAPI(APITestCase):
self.assertJSONEqual(
response.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -87,22 +87,3 @@ class TestRACViews(APITestCase):
)
body = loads(flow_response.content)
self.assertEqual(body["component"], "ak-stage-access-denied")
def test_different_session(self):
"""Test request"""
self.client.force_login(self.user)
response = self.client.get(
reverse(
"authentik_providers_rac:start",
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
)
)
self.assertEqual(response.status_code, 302)
flow_response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
)
body = loads(flow_response.content)
next_url = body["to"]
self.client.logout()
final_response = self.client.get(next_url)
self.assertEqual(final_response.url, reverse("authentik_core:if-user"))

View File

@ -18,14 +18,11 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, FlowPlanner
from authentik.flows.stage import RedirectStage
from authentik.lib.utils.time import timedelta_from_string
from authentik.policies.engine import PolicyEngine
from authentik.policies.views import PolicyAccessView
from authentik.policies.views import BufferedPolicyAccessView
from authentik.providers.rac.models import ConnectionToken, Endpoint, RACProvider
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
PLAN_CONNECTION_SETTINGS = "connection_settings"
class RACStartView(PolicyAccessView):
class RACStartView(BufferedPolicyAccessView):
"""Start a RAC connection by checking access and creating a connection token"""
endpoint: Endpoint
@ -68,10 +65,7 @@ class RACInterface(InterfaceView):
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
# Early sanity check to ensure token still exists
token = ConnectionToken.filter_not_expired(
token=self.kwargs["token"],
session__session__session_key=request.session.session_key,
).first()
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
if not token:
return redirect("authentik_core:if-user")
self.token = token
@ -115,15 +109,10 @@ class RACFinalStage(RedirectStage):
return super().dispatch(request, *args, **kwargs)
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
settings = self.executor.plan.context.get(PLAN_CONNECTION_SETTINGS)
if not settings:
settings = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).get(
PLAN_CONNECTION_SETTINGS
)
token = ConnectionToken.objects.create(
provider=self.provider,
endpoint=self.endpoint,
settings=settings or {},
settings=self.executor.plan.context.get("connection_settings", {}),
session=self.request.session["authenticatedsession"],
expires=now() + timedelta_from_string(self.provider.connection_expiry),
expiring=True,

View File

@ -35,8 +35,8 @@ REQUEST_KEY_SAML_SIG_ALG = "SigAlg"
REQUEST_KEY_SAML_RESPONSE = "SAMLResponse"
REQUEST_KEY_RELAY_STATE = "RelayState"
SESSION_KEY_AUTH_N_REQUEST = "authentik/providers/saml/authn_request"
SESSION_KEY_LOGOUT_REQUEST = "authentik/providers/saml/logout_request"
PLAN_CONTEXT_SAML_AUTH_N_REQUEST = "authentik/providers/saml/authn_request"
PLAN_CONTEXT_SAML_LOGOUT_REQUEST = "authentik/providers/saml/logout_request"
# This View doesn't have a URL on purpose, as its called by the FlowExecutor
@ -50,10 +50,11 @@ class SAMLFlowFinalView(ChallengeStageView):
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
application: Application = self.executor.plan.context[PLAN_CONTEXT_APPLICATION]
provider: SAMLProvider = get_object_or_404(SAMLProvider, pk=application.provider_id)
if SESSION_KEY_AUTH_N_REQUEST not in self.request.session:
if PLAN_CONTEXT_SAML_AUTH_N_REQUEST not in self.executor.plan.context:
self.logger.warning("No AuthNRequest in context")
return self.executor.stage_invalid()
auth_n_request: AuthNRequest = self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST)
auth_n_request: AuthNRequest = self.executor.plan.context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST]
try:
response = AssertionProcessor(provider, request, auth_n_request).build_response()
except SAMLException as exc:
@ -106,6 +107,3 @@ class SAMLFlowFinalView(ChallengeStageView):
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
# We'll never get here since the challenge redirects to the SP
return HttpResponseBadRequest()
def cleanup(self):
self.request.session.pop(SESSION_KEY_AUTH_N_REQUEST, None)

View File

@ -19,9 +19,9 @@ from authentik.providers.saml.exceptions import CannotHandleAssertion
from authentik.providers.saml.models import SAMLProvider
from authentik.providers.saml.processors.logout_request_parser import LogoutRequestParser
from authentik.providers.saml.views.flows import (
PLAN_CONTEXT_SAML_LOGOUT_REQUEST,
REQUEST_KEY_RELAY_STATE,
REQUEST_KEY_SAML_REQUEST,
SESSION_KEY_LOGOUT_REQUEST,
)
LOGGER = get_logger()
@ -33,6 +33,10 @@ class SAMLSLOView(PolicyAccessView):
flow: Flow
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.plan_context = {}
def resolve_provider_application(self):
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
self.provider: SAMLProvider = get_object_or_404(
@ -59,6 +63,7 @@ class SAMLSLOView(PolicyAccessView):
request,
{
PLAN_CONTEXT_APPLICATION: self.application,
**self.plan_context,
},
)
plan.append_stage(in_memory_stage(SessionEndStage))
@ -83,7 +88,7 @@ class SAMLSLOBindingRedirectView(SAMLSLOView):
self.request.GET[REQUEST_KEY_SAML_REQUEST],
relay_state=self.request.GET.get(REQUEST_KEY_RELAY_STATE, None),
)
self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request
self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request
except CannotHandleAssertion as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -111,7 +116,7 @@ class SAMLSLOBindingPOSTView(SAMLSLOView):
payload[REQUEST_KEY_SAML_REQUEST],
relay_state=payload.get(REQUEST_KEY_RELAY_STATE, None),
)
self.request.session[SESSION_KEY_LOGOUT_REQUEST] = logout_request
self.plan_context[PLAN_CONTEXT_SAML_LOGOUT_REQUEST] = logout_request
except CannotHandleAssertion as exc:
LOGGER.info(str(exc))
return bad_request_message(self.request, str(exc))

View File

@ -15,16 +15,16 @@ from authentik.flows.models import in_memory_stage
from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner
from authentik.flows.views.executor import SESSION_KEY_POST
from authentik.lib.views import bad_request_message
from authentik.policies.views import PolicyAccessView
from authentik.policies.views import BufferedPolicyAccessView
from authentik.providers.saml.exceptions import CannotHandleAssertion
from authentik.providers.saml.models import SAMLBindings, SAMLProvider
from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser
from authentik.providers.saml.views.flows import (
PLAN_CONTEXT_SAML_AUTH_N_REQUEST,
REQUEST_KEY_RELAY_STATE,
REQUEST_KEY_SAML_REQUEST,
REQUEST_KEY_SAML_SIG_ALG,
REQUEST_KEY_SAML_SIGNATURE,
SESSION_KEY_AUTH_N_REQUEST,
SAMLFlowFinalView,
)
from authentik.stages.consent.stage import (
@ -35,10 +35,14 @@ from authentik.stages.consent.stage import (
LOGGER = get_logger()
class SAMLSSOView(PolicyAccessView):
class SAMLSSOView(BufferedPolicyAccessView):
"""SAML SSO Base View, which plans a flow and injects our final stage.
Calls get/post handler."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.plan_context = {}
def resolve_provider_application(self):
self.application = get_object_or_404(Application, slug=self.kwargs["application_slug"])
self.provider: SAMLProvider = get_object_or_404(
@ -68,6 +72,7 @@ class SAMLSSOView(PolicyAccessView):
PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
% {"application": self.application.name},
PLAN_CONTEXT_CONSENT_PERMISSIONS: [],
**self.plan_context,
},
)
except FlowNonApplicableException:
@ -83,7 +88,7 @@ class SAMLSSOView(PolicyAccessView):
def post(self, request: HttpRequest, application_slug: str) -> HttpResponse:
"""GET and POST use the same handler, but we can't
override .dispatch easily because PolicyAccessView's dispatch"""
override .dispatch easily because BufferedPolicyAccessView's dispatch"""
return self.get(request, application_slug)
@ -103,7 +108,7 @@ class SAMLSSOBindingRedirectView(SAMLSSOView):
self.request.GET.get(REQUEST_KEY_SAML_SIGNATURE),
self.request.GET.get(REQUEST_KEY_SAML_SIG_ALG),
)
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request
except CannotHandleAssertion as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
@ -137,7 +142,7 @@ class SAMLSSOBindingPOSTView(SAMLSSOView):
payload[REQUEST_KEY_SAML_REQUEST],
payload.get(REQUEST_KEY_RELAY_STATE),
)
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request
except CannotHandleAssertion as exc:
LOGGER.info(str(exc))
return bad_request_message(self.request, str(exc))
@ -151,4 +156,4 @@ class SAMLSSOBindingInitView(SAMLSSOView):
"""Create SAML Response from scratch"""
LOGGER.debug("No SAML Request, using IdP-initiated flow.")
auth_n_request = AuthNRequestParser(self.provider).idp_initiated()
self.request.session[SESSION_KEY_AUTH_N_REQUEST] = auth_n_request
self.plan_context[PLAN_CONTEXT_SAML_AUTH_N_REQUEST] = auth_n_request

View File

@ -5,6 +5,7 @@ from itertools import batched
from django.db import transaction
from pydantic import ValidationError
from pydanticscim.group import GroupMember
from pydanticscim.responses import PatchOp
from authentik.core.models import Group
from authentik.lib.sync.mapper import PropertyMappingManager
@ -19,12 +20,7 @@ from authentik.providers.scim.clients.base import SCIMClient
from authentik.providers.scim.clients.exceptions import (
SCIMRequestException,
)
from authentik.providers.scim.clients.schema import (
SCIM_GROUP_SCHEMA,
PatchOp,
PatchOperation,
PatchRequest,
)
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOperation, PatchRequest
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.providers.scim.models import (
SCIMMapping,

View File

@ -1,7 +1,5 @@
"""Custom SCIM schemas"""
from enum import Enum
from pydantic import Field
from pydanticscim.group import Group as BaseGroup
from pydanticscim.responses import PatchOperation as BasePatchOperation
@ -67,21 +65,6 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
)
class PatchOp(str, Enum):
replace = "replace"
remove = "remove"
add = "add"
@classmethod
def _missing_(cls, value):
value = value.lower()
for member in cls:
if member.lower() == value:
return member
return None
class PatchRequest(BasePatchRequest):
"""PatchRequest which correctly sets schemas"""
@ -91,7 +74,6 @@ class PatchRequest(BasePatchRequest):
class PatchOperation(BasePatchOperation):
"""PatchOperation with optional path"""
op: PatchOp
path: str | None

View File

@ -44,7 +44,6 @@ class TestRBACRoleAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -46,7 +46,6 @@ class TestRBACUserAPI(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -38,7 +38,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,
@ -74,7 +73,6 @@ class TestAPIPerms(APITestCase):
self.assertJSONEqual(
res.content.decode(),
{
"autocomplete": {},
"pagination": {
"next": 0,
"previous": 0,

View File

@ -9,14 +9,13 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
import django
from channels.routing import ProtocolTypeRouter, URLRouter
from defusedxml import defuse_stdlib
from django.core.asgi import get_asgi_application
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from authentik.root.setup import setup
# DJANGO_SETTINGS_MODULE is set in gunicorn.conf.py
setup()
defuse_stdlib()
django.setup()

View File

@ -27,7 +27,7 @@ from structlog.stdlib import get_logger
from tenant_schemas_celery.app import CeleryApp as TenantAwareCeleryApp
from authentik import get_full_version
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string
# set the default Django settings module for the 'celery' program.
@ -81,7 +81,7 @@ def task_error_hook(task_id: str, exception: Exception, traceback, *args, **kwar
LOGGER.warning("Task failure", task_id=task_id.replace("-", ""), exc=exception)
CTX_TASK_ID.set(...)
if not should_ignore_exception(exception):
if before_send({}, {"exc_info": (None, exception, None)}) is not None:
Event.new(
EventAction.SYSTEM_EXCEPTION, message=exception_to_string(exception), task_id=task_id
).save()

View File

@ -1,49 +1,13 @@
"""authentik database backend"""
from django.core.checks import Warning
from django.db.backends.base.validation import BaseDatabaseValidation
from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
from authentik.lib.config import CONFIG
class DatabaseValidation(BaseDatabaseValidation):
def check(self, **kwargs):
return self._check_encoding()
def _check_encoding(self):
"""Throw a warning when the server_encoding is not UTF-8 or
server_encoding and client_encoding are mismatched"""
messages = []
with self.connection.cursor() as cursor:
cursor.execute("SHOW server_encoding;")
server_encoding = cursor.fetchone()[0]
cursor.execute("SHOW client_encoding;")
client_encoding = cursor.fetchone()[0]
if server_encoding != client_encoding:
messages.append(
Warning(
"PostgreSQL Server and Client encoding are mismatched: Server: "
f"{server_encoding}, Client: {client_encoding}",
id="ak.db.W001",
)
)
if server_encoding != "UTF8":
messages.append(
Warning(
f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
id="ak.db.W002",
)
)
return messages
class DatabaseWrapper(BaseDatabaseWrapper):
"""database backend which supports rotating credentials"""
validation_class = DatabaseValidation
def get_connection_params(self):
"""Refresh DB credentials before getting connection params"""
conn_params = super().get_connection_params()

View File

@ -446,8 +446,6 @@ _DISALLOWED_ITEMS = [
"MIDDLEWARE",
"AUTHENTICATION_BACKENDS",
"CELERY",
"SPECTACULAR_SETTINGS",
"REST_FRAMEWORK",
]
SILENCED_SYSTEM_CHECKS = [
@ -470,8 +468,6 @@ def _update_settings(app_path: str):
TENANT_APPS.extend(getattr(settings_module, "TENANT_APPS", []))
MIDDLEWARE.extend(getattr(settings_module, "MIDDLEWARE", []))
AUTHENTICATION_BACKENDS.extend(getattr(settings_module, "AUTHENTICATION_BACKENDS", []))
SPECTACULAR_SETTINGS.update(getattr(settings_module, "SPECTACULAR_SETTINGS", {}))
REST_FRAMEWORK.update(getattr(settings_module, "REST_FRAMEWORK", {}))
CELERY["beat_schedule"].update(getattr(settings_module, "CELERY_BEAT_SCHEDULE", {}))
for _attr in dir(settings_module):
if not _attr.startswith("__") and _attr not in _DISALLOWED_ITEMS:

View File

@ -1,26 +0,0 @@
import os
import warnings
from cryptography.hazmat.backends.openssl.backend import backend
from defusedxml import defuse_stdlib
from authentik.lib.config import CONFIG
def setup():
warnings.filterwarnings("ignore", "SelectableGroups dict interface")
warnings.filterwarnings(
"ignore",
"defusedxml.lxml is no longer supported and will be removed in a future release.",
)
warnings.filterwarnings(
"ignore",
"defusedxml.cElementTree is deprecated, import from defusedxml.ElementTree instead.",
)
defuse_stdlib()
if CONFIG.get_bool("compliance.fips.enabled", False):
backend._enable_fips()
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")

View File

@ -11,8 +11,6 @@ from django.contrib.contenttypes.models import ContentType
from django.test.runner import DiscoverRunner
from structlog.stdlib import get_logger
from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
from authentik.lib.config import CONFIG
from authentik.lib.sentry import sentry_init
from authentik.root.signals import post_startup, pre_startup, startup
@ -78,9 +76,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
for key, value in test_config.items():
CONFIG.set(key, value)
ASN_CONTEXT_PROCESSOR.load()
GEOIP_CONTEXT_PROCESSOR.load()
sentry_init()
self.logger.debug("Test environment configured")

View File

@ -71,31 +71,37 @@ def ldap_sync_single(source_pk: str):
return
# Delete all sync tasks from the cache
DBSystemTask.objects.filter(name="ldap_sync", uid__startswith=source.slug).delete()
# The order of these operations needs to be preserved as each depends on the previous one(s)
# 1. User and group sync can happen simultaneously
# 2. Membership sync needs to run afterwards
# 3. Finally, user and group deletions can happen simultaneously
user_group_sync = ldap_sync_paginator(source, UserLDAPSynchronizer) + ldap_sync_paginator(
source, GroupLDAPSynchronizer
task = chain(
# User and group sync can happen at once, they have no dependencies on each other
group(
ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer),
),
# Membership sync needs to run afterwards
group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer),
),
# Finally, deletions. What we'd really like to do here is something like
# ```
# user_identifiers = <ldap query>
# User.objects.exclude(
# usersourceconnection__identifier__in=user_uniqueness_identifiers,
# ).delete()
# ```
# This runs into performance issues in large installations. So instead we spread the
# work out into three steps:
# 1. Get every object from the LDAP source.
# 2. Mark every object as "safe" in the database. This is quick, but any error could
# mean deleting users which should not be deleted, so we do it immediately, in
# large chunks, and only queue the deletion step afterwards.
# 3. Delete every unmarked item. This is slow, so we spread it over many tasks in
# small chunks.
group(
ldap_sync_paginator(source, UserLDAPForwardDeletion)
+ ldap_sync_paginator(source, GroupLDAPForwardDeletion),
),
)
membership_sync = ldap_sync_paginator(source, MembershipLDAPSynchronizer)
user_group_deletion = ldap_sync_paginator(
source, UserLDAPForwardDeletion
) + ldap_sync_paginator(source, GroupLDAPForwardDeletion)
# Celery is buggy with empty groups, so we are careful only to add non-empty groups.
# See https://github.com/celery/celery/issues/9772
task_groups = []
if user_group_sync:
task_groups.append(group(user_group_sync))
if membership_sync:
task_groups.append(group(membership_sync))
if user_group_deletion:
task_groups.append(group(user_group_deletion))
all_tasks = chain(task_groups)
all_tasks()
task()
def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> list:

View File

@ -1,277 +0,0 @@
"""Test SCIM Group"""
from json import dumps
from uuid import uuid4
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.models import Group
from authentik.core.tests.utils import create_test_user
from authentik.events.models import Event, EventAction
from authentik.lib.generators import generate_id
from authentik.providers.scim.clients.schema import Group as SCIMGroupSchema
from authentik.sources.scim.models import (
SCIMSource,
SCIMSourceGroup,
)
from authentik.sources.scim.views.v2.base import SCIM_CONTENT_TYPE
class TestSCIMGroups(APITestCase):
"""Test SCIM Group view"""
def setUp(self) -> None:
self.source = SCIMSource.objects.create(name=generate_id(), slug=generate_id())
def test_group_list(self):
"""Test full group list"""
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_group_list_single(self):
"""Test full group list (single group)"""
group = Group.objects.create(name=generate_id())
user = create_test_user()
group.users.add(user)
SCIMSourceGroup.objects.create(
source=self.source,
group=group,
id=str(uuid4()),
)
response = self.client.get(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(group.pk),
},
),
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
SCIMGroupSchema.model_validate_json(response.content, strict=True)
def test_group_create(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members(self):
"""Test group create"""
user = create_test_user()
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{
"displayName": generate_id(),
"externalId": ext_id,
"members": [{"value": str(user.uuid)}],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_members_empty(self):
"""Test group create"""
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "members": []}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 201)
self.assertTrue(SCIMSourceGroup.objects.filter(source=self.source, id=ext_id).exists())
self.assertTrue(
Event.objects.filter(
action=EventAction.MODEL_CREATED, user__username=self.source.token.user.username
).exists()
)
def test_group_create_duplicate(self):
"""Test group create (duplicate)"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.post(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.group.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 409)
self.assertJSONEqual(
response.content,
{
"detail": "Group with ID exists already.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"scimType": "uniqueness",
"status": 409,
},
)
def test_group_update(self):
"""Test group update"""
group = Group.objects.create(name=generate_id())
existing = SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{"displayName": generate_id(), "externalId": ext_id, "id": str(existing.pk)}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
def test_group_update_non_existent(self):
"""Test group update"""
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={
"source_slug": self.source.slug,
"group_id": str(uuid4()),
},
),
data=dumps({"displayName": generate_id(), "externalId": ext_id, "id": ""}),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=404)
self.assertJSONEqual(
response.content,
{
"detail": "Group not found.",
"schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
"status": 404,
},
)
def test_group_patch_add(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "Add",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertTrue(group.users.filter(pk=user.pk).exists())
def test_group_patch_remove(self):
"""Test group patch"""
user = create_test_user()
group = Group.objects.create(name=generate_id())
group.users.add(user)
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.patch(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
data=dumps(
{
"Operations": [
{
"op": "remove",
"path": "members",
"value": {"value": str(user.uuid)},
}
]
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=200)
self.assertFalse(group.users.filter(pk=user.pk).exists())
def test_group_delete(self):
"""Test group delete"""
group = Group.objects.create(name=generate_id())
SCIMSourceGroup.objects.create(source=self.source, group=group, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-groups",
kwargs={"source_slug": self.source.slug, "group_id": group.pk},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, second=204)

View File

@ -177,51 +177,3 @@ class TestSCIMUsers(APITestCase):
SCIMSourceUser.objects.get(source=self.source, id=ext_id).user.attributes["phone"],
"0123456789",
)
def test_user_update(self):
"""Test user update"""
user = create_test_user()
existing = SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
ext_id = generate_id()
response = self.client.put(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
data=dumps(
{
"id": str(existing.pk),
"userName": generate_id(),
"externalId": ext_id,
"emails": [
{
"primary": True,
"value": user.email,
}
],
}
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 200)
def test_user_delete(self):
"""Test user delete"""
user = create_test_user()
SCIMSourceUser.objects.create(source=self.source, user=user, id=uuid4())
response = self.client.delete(
reverse(
"authentik_sources_scim:v2-users",
kwargs={
"source_slug": self.source.slug,
"user_id": str(user.uuid),
},
),
content_type=SCIM_CONTENT_TYPE,
HTTP_AUTHORIZATION=f"Bearer {self.source.token.key}",
)
self.assertEqual(response.status_code, 204)

View File

@ -8,7 +8,6 @@ from rest_framework.authentication import BaseAuthentication, get_authorization_
from rest_framework.request import Request
from rest_framework.views import APIView
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import Token, TokenIntents, User
from authentik.sources.scim.models import SCIMSource
@ -27,7 +26,6 @@ class SCIMTokenAuth(BaseAuthentication):
_username, _, password = b64decode(key.encode()).decode().partition(":")
token = self.check_token(password, source_slug)
if token:
CTX_AUTH_VIA.set("scim_basic")
return (token.user, token)
return None
@ -54,5 +52,4 @@ class SCIMTokenAuth(BaseAuthentication):
token = self.check_token(key, source_slug)
if not token:
return None
CTX_AUTH_VIA.set("scim_token")
return (token.user, token)

View File

@ -1,11 +1,13 @@
"""SCIM Utils"""
from typing import Any
from urllib.parse import urlparse
from django.conf import settings
from django.core.paginator import Page, Paginator
from django.db.models import Q, QuerySet
from django.http import HttpRequest
from django.urls import resolve
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer
@ -44,7 +46,7 @@ class SCIMView(APIView):
logger: BoundLogger
permission_classes = [IsAuthenticated]
parser_classes = [SCIMParser, JSONParser]
parser_classes = [SCIMParser]
renderer_classes = [SCIMRenderer]
def setup(self, request: HttpRequest, *args: Any, **kwargs: Any) -> None:
@ -54,6 +56,28 @@ class SCIMView(APIView):
def get_authenticators(self):
return [SCIMTokenAuth(self)]
def patch_resolve_value(self, raw_value: dict) -> User | Group | None:
"""Attempt to resolve a raw `value` attribute of a patch operation into
a database model"""
model = User
query = {}
if "$ref" in raw_value:
url = urlparse(raw_value["$ref"])
if match := resolve(url.path):
if match.url_name == "v2-users":
model = User
query = {"pk": int(match.kwargs["user_id"])}
elif "type" in raw_value:
match raw_value["type"]:
case "User":
model = User
query = {"pk": int(raw_value["value"])}
case "Group":
model = Group
else:
return None
return model.objects.filter(**query).first()
def filter_parse(self, request: Request):
"""Parse the path of a Patch Operation"""
path = request.query_params.get("filter")

View File

@ -1,58 +0,0 @@
from enum import Enum
from pydanticscim.responses import SCIMError as BaseSCIMError
from rest_framework.exceptions import ValidationError
class SCIMErrorTypes(Enum):
invalid_filter = "invalidFilter"
too_many = "tooMany"
uniqueness = "uniqueness"
mutability = "mutability"
invalid_syntax = "invalidSyntax"
invalid_path = "invalidPath"
no_target = "noTarget"
invalid_value = "invalidValue"
invalid_vers = "invalidVers"
sensitive = "sensitive"
class SCIMError(BaseSCIMError):
scimType: SCIMErrorTypes | None = None
detail: str | None = None
class SCIMValidationError(ValidationError):
status_code = 400
default_detail = SCIMError(scimType=SCIMErrorTypes.invalid_syntax, status=400)
def __init__(self, detail: SCIMError | None):
if detail is None:
detail = self.default_detail
detail.status = self.status_code
self.detail = detail.model_dump(mode="json", exclude_none=True)
class SCIMConflictError(SCIMValidationError):
status_code = 409
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
scimType=SCIMErrorTypes.uniqueness,
status=self.status_code,
)
)
class SCIMNotFoundError(SCIMValidationError):
status_code = 404
def __init__(self, detail: str):
super().__init__(
SCIMError(
detail=detail,
status=self.status_code,
)
)

View File

@ -4,25 +4,19 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydantic import ValidationError as PydanticValidationError
from pydanticscim.group import GroupMember
from rest_framework.exceptions import ValidationError
from rest_framework.request import Request
from rest_framework.response import Response
from scim2_filter_parser.attr_paths import AttrPath
from authentik.core.models import Group, User
from authentik.providers.scim.clients.schema import SCIM_GROUP_SCHEMA, PatchOp, PatchOperation
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
from authentik.sources.scim.models import SCIMSourceGroup
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import (
SCIMConflictError,
SCIMNotFoundError,
SCIMValidationError,
)
class GroupsView(SCIMObjectView):
@ -33,7 +27,7 @@ class GroupsView(SCIMObjectView):
def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
"""Convert Group to SCIM data"""
payload = SCIMGroupModel(
schemas=[SCIM_GROUP_SCHEMA],
schemas=[SCIM_USER_SCHEMA],
id=str(scim_group.group.pk),
externalId=scim_group.id,
displayName=scim_group.group.name,
@ -64,7 +58,7 @@ class GroupsView(SCIMObjectView):
if group_id:
connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
return Response(self.group_to_scim(connection))
connections = (
base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
@ -125,7 +119,7 @@ class GroupsView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing group")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_group(None, request.data)
return Response(self.group_to_scim(connection), status=201)
@ -135,44 +129,10 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection = self.update_group(connection, request.data)
return Response(self.group_to_scim(connection), status=200)
@atomic
def patch(self, request: Request, group_id: str, **kwargs) -> Response:
"""Patch group handler"""
connection = SCIMSourceGroup.objects.filter(
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
for _op in request.data.get("Operations", []):
operation = PatchOperation.model_validate(_op)
if operation.op.lower() not in ["add", "remove", "replace"]:
raise SCIMValidationError()
attr_path = AttrPath(f'{operation.path} eq ""', {})
if attr_path.first_path == ("members", None, None):
# FIXME: this can probably be de-duplicated
if operation.op == PatchOp.add:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.add(*User.objects.filter(query))
elif operation.op == PatchOp.remove:
if not isinstance(operation.value, list):
operation.value = [operation.value]
query = Q()
for member in operation.value:
query |= Q(uuid=member["value"])
if query:
connection.group.users.remove(*User.objects.filter(query))
return Response(self.group_to_scim(connection), status=200)
@atomic
def delete(self, request: Request, group_id: str, **kwargs) -> Response:
"""Delete group handler"""
@ -180,7 +140,7 @@ class GroupsView(SCIMObjectView):
source=self.source, group__group_uuid=group_id
).first()
if not connection:
raise SCIMNotFoundError("Group not found.")
raise Http404
connection.group.delete()
connection.delete()
return Response(status=204)

View File

@ -1,11 +1,11 @@
"""SCIM Meta views"""
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
class ResourceTypesView(SCIMView):
@ -138,7 +138,7 @@ class ResourceTypesView(SCIMView):
resource = [x for x in resource_types if x.get("id") == resource_type]
if resource:
return Response(resource[0])
raise SCIMNotFoundError("Resource not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -3,12 +3,12 @@
from json import loads
from django.conf import settings
from django.http import Http404
from django.urls import reverse
from rest_framework.request import Request
from rest_framework.response import Response
from authentik.sources.scim.views.v2.base import SCIMView
from authentik.sources.scim.views.v2.exceptions import SCIMNotFoundError
with open(
settings.BASE_DIR / "authentik" / "sources" / "scim" / "schemas" / "schema.json",
@ -44,7 +44,7 @@ class SchemaView(SCIMView):
schema = [x for x in schemas if x.get("id") == schema_uri]
if schema:
return Response(schema[0])
raise SCIMNotFoundError("Schema not found.")
raise Http404
return Response(
{
"schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],

View File

@ -33,8 +33,6 @@ class ServiceProviderConfigView(SCIMView):
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
"authenticationSchemes": auth_schemas,
# We only support patch for groups currently, so don't broadly advertise it.
# Implementations that require Group patch will use it regardless of this flag.
"patch": {"supported": False},
"bulk": {"supported": False, "maxOperations": 0, "maxPayloadSize": 0},
"filter": {

View File

@ -4,7 +4,7 @@ from uuid import uuid4
from django.db.models import Q
from django.db.transaction import atomic
from django.http import QueryDict
from django.http import Http404, QueryDict
from django.urls import reverse
from pydanticscim.user import Email, EmailKind, Name
from rest_framework.exceptions import ValidationError
@ -16,7 +16,6 @@ from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
from authentik.providers.scim.clients.schema import User as SCIMUserModel
from authentik.sources.scim.models import SCIMSourceUser
from authentik.sources.scim.views.v2.base import SCIMObjectView
from authentik.sources.scim.views.v2.exceptions import SCIMConflictError, SCIMNotFoundError
class UsersView(SCIMObjectView):
@ -70,7 +69,7 @@ class UsersView(SCIMObjectView):
.first()
)
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
return Response(self.user_to_scim(connection))
connections = (
SCIMSourceUser.objects.filter(source=self.source).select_related("user").order_by("pk")
@ -123,7 +122,7 @@ class UsersView(SCIMObjectView):
).first()
if connection:
self.logger.debug("Found existing user")
raise SCIMConflictError("Group with ID exists already.")
return Response(status=409)
connection = self.update_user(None, request.data)
return Response(self.user_to_scim(connection), status=201)
@ -131,7 +130,7 @@ class UsersView(SCIMObjectView):
"""Update user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
self.update_user(connection, request.data)
return Response(self.user_to_scim(connection), status=200)
@ -140,7 +139,7 @@ class UsersView(SCIMObjectView):
"""Delete user handler"""
connection = SCIMSourceUser.objects.filter(source=self.source, user__uuid=user_id).first()
if not connection:
raise SCIMNotFoundError("User not found.")
raise Http404
connection.user.delete()
connection.delete()
return Response(status=204)

View File

@ -1,7 +1,6 @@
"""Validation stage challenge checking"""
from json import loads
from typing import TYPE_CHECKING
from urllib.parse import urlencode
from django.http import HttpRequest
@ -37,12 +36,10 @@ from authentik.stages.authenticator_email.models import EmailDevice
from authentik.stages.authenticator_sms.models import SMSDevice
from authentik.stages.authenticator_validate.models import AuthenticatorValidateStage, DeviceClasses
from authentik.stages.authenticator_webauthn.models import UserVerification, WebAuthnDevice
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.utils import get_origin, get_rp_id
LOGGER = get_logger()
if TYPE_CHECKING:
from authentik.stages.authenticator_validate.stage import AuthenticatorValidateStageView
class DeviceChallenge(PassiveSerializer):
@ -55,11 +52,11 @@ class DeviceChallenge(PassiveSerializer):
def get_challenge_for_device(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
request: HttpRequest, stage: AuthenticatorValidateStage, device: Device
) -> dict:
"""Generate challenge for a single device"""
if isinstance(device, WebAuthnDevice):
return get_webauthn_challenge(stage_view, stage, device)
return get_webauthn_challenge(request, stage, device)
if isinstance(device, EmailDevice):
return {"email": mask_email(device.email)}
# Code-based challenges have no hints
@ -67,30 +64,26 @@ def get_challenge_for_device(
def get_webauthn_challenge_without_user(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
request: HttpRequest, stage: AuthenticatorValidateStage
) -> dict:
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check
who the device belongs to."""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=[],
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
def get_webauthn_challenge(
stage_view: "AuthenticatorValidateStageView",
stage: AuthenticatorValidateStage,
device: WebAuthnDevice | None = None,
request: HttpRequest, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None
) -> dict:
"""Send the client a challenge that we'll check later"""
stage_view.executor.plan.context.pop(PLAN_CONTEXT_WEBAUTHN_CHALLENGE, None)
request.session.pop(SESSION_KEY_WEBAUTHN_CHALLENGE, None)
allowed_credentials = []
@ -101,14 +94,12 @@ def get_webauthn_challenge(
allowed_credentials.append(user_device.descriptor)
authentication_options = generate_authentication_options(
rp_id=get_rp_id(stage_view.request),
rp_id=get_rp_id(request),
allow_credentials=allowed_credentials,
user_verification=UserVerificationRequirement(stage.webauthn_user_verification),
)
stage_view.executor.plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = (
authentication_options.challenge
)
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = authentication_options.challenge
return loads(options_to_json(authentication_options))
@ -155,7 +146,7 @@ def validate_challenge_code(code: str, stage_view: StageView, user: User) -> Dev
def validate_challenge_webauthn(data: dict, stage_view: StageView, user: User) -> Device:
"""Validate WebAuthn Challenge"""
request = stage_view.request
challenge = stage_view.executor.plan.context.get(PLAN_CONTEXT_WEBAUTHN_CHALLENGE)
challenge = request.session.get(SESSION_KEY_WEBAUTHN_CHALLENGE)
stage: AuthenticatorValidateStage = stage_view.executor.current_stage
try:
credential = parse_authentication_credential_json(data)

View File

@ -224,7 +224,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
data={
"device_class": device_class,
"device_uid": device.pk,
"challenge": get_challenge_for_device(self, stage, device),
"challenge": get_challenge_for_device(self.request, stage, device),
"last_used": device.last_used,
}
)
@ -243,7 +243,7 @@ class AuthenticatorValidateStageView(ChallengeStageView):
"device_class": DeviceClasses.WEBAUTHN,
"device_uid": -1,
"challenge": get_webauthn_challenge_without_user(
self,
self.request,
self.executor.current_stage,
),
"last_used": None,

View File

@ -31,7 +31,7 @@ from authentik.stages.authenticator_webauthn.models import (
WebAuthnDevice,
WebAuthnDeviceType,
)
from authentik.stages.authenticator_webauthn.stage import PLAN_CONTEXT_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.stage import SESSION_KEY_WEBAUTHN_CHALLENGE
from authentik.stages.authenticator_webauthn.tasks import webauthn_mds_import
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.user_login.models import UserLoginStage
@ -103,11 +103,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
device_classes=[DeviceClasses.WEBAUTHN],
webauthn_user_verification=UserVerification.PREFERRED,
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
del challenge["challenge"]
self.assertEqual(
challenge,
@ -126,9 +122,7 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
with self.assertRaises(ValidationError):
validate_challenge_webauthn(
{},
StageView(FlowExecutorView(current_stage=stage, plan=plan), request=request),
self.user,
{}, StageView(FlowExecutorView(current_stage=stage), request=request), self.user
)
def test_device_challenge_webauthn_restricted(self):
@ -199,35 +193,22 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
)
challenge = get_challenge_for_device(stage_view, stage, webauthn_device)
challenge = get_challenge_for_device(request, stage, webauthn_device)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge["allowCredentials"],
[
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
)
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(
challenge["rpId"],
"testserver",
)
self.assertEqual(
challenge["timeout"],
60000,
)
self.assertEqual(
challenge["userVerification"],
"preferred",
challenge,
{
"allowCredentials": [
{
"id": "QKZ97ASJAOIDyipAs6mKUxDUZgDrWrbAsUb5leL7-oU",
"type": "public-key",
}
],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
def test_get_challenge_userless(self):
@ -247,16 +228,18 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
sign_count=0,
rp_id=generate_id(),
)
plan = FlowPlan("")
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=None, current_stage=stage, plan=plan), request=request
challenge = get_webauthn_challenge_without_user(request, stage)
webauthn_challenge = request.session[SESSION_KEY_WEBAUTHN_CHALLENGE]
self.assertEqual(
challenge,
{
"allowCredentials": [],
"challenge": bytes_to_base64url(webauthn_challenge),
"rpId": "testserver",
"timeout": 60000,
"userVerification": "preferred",
},
)
challenge = get_webauthn_challenge_without_user(stage_view, stage)
self.assertEqual(challenge["allowCredentials"], [])
self.assertIsNotNone(challenge["challenge"])
self.assertEqual(challenge["rpId"], "testserver")
self.assertEqual(challenge["timeout"], 60000)
self.assertEqual(challenge["userVerification"], "preferred")
def test_validate_challenge_unrestricted(self):
"""Test webauthn authentication (unrestricted webauthn device)"""
@ -292,10 +275,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -369,10 +352,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"aCC6ak_DP45xMH1qyxzUM5iC2xc4QthQb09v7m4qDBmY8FvWvhxFzSuFlDYQmclrh5fWS5q0TPxgJGF4vimcFQ"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -450,10 +433,10 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
"last_used": None,
}
]
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
session[SESSION_KEY_PLAN] = plan
session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
session[SESSION_KEY_PLAN] = plan
session.save()
response = self.client.post(
@ -513,14 +496,17 @@ class AuthenticatorValidateStageWebAuthnTests(FlowTestCase):
not_configured_action=NotConfiguredAction.CONFIGURE,
device_classes=[DeviceClasses.WEBAUTHN],
)
plan = FlowPlan(flow.pk.hex)
plan.context[PLAN_CONTEXT_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request = get_request("/")
request.session[SESSION_KEY_WEBAUTHN_CHALLENGE] = base64url_to_bytes(
"g98I51mQvZXo5lxLfhrD2zfolhZbLRyCgqkkYap1jwSaJ13BguoJWCF9_Lg3AgO4Wh-Bqa556JE20oKsYbl6RA"
)
request.session.save()
stage_view = AuthenticatorValidateStageView(
FlowExecutorView(flow=flow, current_stage=stage, plan=plan), request=request
FlowExecutorView(flow=flow, current_stage=stage), request=request
)
request.META["SERVER_NAME"] = "localhost"
request.META["SERVER_PORT"] = "9000"

Some files were not shown because too many files have changed in this diff Show More