Compare commits
	
		
			50 Commits
		
	
	
		
			version/20
			...
			version/20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 1a21479b0d | |||
| 38154f72e0 | |||
| 19318d4c00 | |||
| be3d7c0666 | |||
| 5afceaa55f | |||
| 72dc27f1c9 | |||
| b5ffd16861 | |||
| 8af754e88c | |||
| ade1f08c89 | |||
| 9240fa1037 | |||
| 1f5953b5b7 | |||
| 5befccc1fd | |||
| ff193d809a | |||
| 23bbb6e5ef | |||
| 225d02d02d | |||
| 90fe1eda66 | |||
| 35ba88a203 | |||
| 8414a9dcad | |||
| 1d626f5b57 | |||
| 508dd0ac64 | |||
| f4b82a8b09 | |||
| 2900f01976 | |||
| 0f6ece5eb7 | |||
| b9936fe532 | |||
| d0b3cc5916 | |||
| e034f5e5dc | |||
| 9d6816bbc8 | |||
| 82d4ea9e8a | |||
| c8a804f2a7 | |||
| ca70c963e5 | |||
| 4c89d4a4a4 | |||
| 8a47acac3a | |||
| 4a3b22491c | |||
| f991d656c7 | |||
| e86aa11131 | |||
| 03725ae086 | |||
| f2a37e8c7c | |||
| e935690b1b | |||
| 02709e4ede | |||
| f78adab9d1 | |||
| 61f3a72fd9 | |||
| 541becfe30 | |||
| 11ff7955f7 | |||
| afa4234036 | |||
| ca22a4deaf | |||
| 7b7a3d34ec | |||
| b1ca579397 | |||
| c8072579c8 | |||
| 378a701fb9 | |||
| bba793d94c | 
@ -1,5 +1,5 @@
 | 
			
		||||
[bumpversion]
 | 
			
		||||
current_version = 2024.2.3
 | 
			
		||||
current_version = 2024.4.4
 | 
			
		||||
tag = True
 | 
			
		||||
commit = True
 | 
			
		||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))?
 | 
			
		||||
@ -21,6 +21,8 @@ optional_value = final
 | 
			
		||||
 | 
			
		||||
[bumpversion:file:schema.yml]
 | 
			
		||||
 | 
			
		||||
[bumpversion:file:blueprints/schema.json]
 | 
			
		||||
 | 
			
		||||
[bumpversion:file:authentik/__init__.py]
 | 
			
		||||
 | 
			
		||||
[bumpversion:file:internal/constants/constants.go]
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower()
 | 
			
		||||
branch_name = os.environ["GITHUB_REF"]
 | 
			
		||||
if os.environ.get("GITHUB_HEAD_REF", "") != "":
 | 
			
		||||
    branch_name = os.environ["GITHUB_HEAD_REF"]
 | 
			
		||||
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
 | 
			
		||||
safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-")
 | 
			
		||||
 | 
			
		||||
image_names = os.getenv("IMAGE_NAME").split(",")
 | 
			
		||||
image_arch = os.getenv("IMAGE_ARCH") or None
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							@ -34,6 +34,13 @@ jobs:
 | 
			
		||||
      - name: Eslint
 | 
			
		||||
        working-directory: ${{ matrix.project }}/
 | 
			
		||||
        run: npm run lint
 | 
			
		||||
  lint-lockfile:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@v4
 | 
			
		||||
      - working-directory: web/
 | 
			
		||||
        run: |
 | 
			
		||||
          [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
 | 
			
		||||
  lint-build:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
@ -95,6 +102,7 @@ jobs:
 | 
			
		||||
        run: npm run lit-analyse
 | 
			
		||||
  ci-web-mark:
 | 
			
		||||
    needs:
 | 
			
		||||
      - lint-lockfile
 | 
			
		||||
      - lint-eslint
 | 
			
		||||
      - lint-prettier
 | 
			
		||||
      - lint-lit-analyse
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							@ -12,6 +12,13 @@ on:
 | 
			
		||||
      - version-*
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  lint-lockfile:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@v4
 | 
			
		||||
      - working-directory: website/
 | 
			
		||||
        run: |
 | 
			
		||||
          [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ]
 | 
			
		||||
  lint-prettier:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    steps:
 | 
			
		||||
@ -62,6 +69,7 @@ jobs:
 | 
			
		||||
        run: npm run ${{ matrix.job }}
 | 
			
		||||
  ci-website-mark:
 | 
			
		||||
    needs:
 | 
			
		||||
      - lint-lockfile
 | 
			
		||||
      - lint-prettier
 | 
			
		||||
      - test
 | 
			
		||||
      - build
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@
 | 
			
		||||
 | 
			
		||||
from os import environ
 | 
			
		||||
 | 
			
		||||
__version__ = "2024.2.3"
 | 
			
		||||
__version__ = "2024.4.4"
 | 
			
		||||
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -154,12 +154,18 @@ class GroupViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
 | 
			
		||||
        pk = IntegerField(required=True)
 | 
			
		||||
 | 
			
		||||
    queryset = Group.objects.all().select_related("parent").prefetch_related("users")
 | 
			
		||||
    queryset = Group.objects.none()
 | 
			
		||||
    serializer_class = GroupSerializer
 | 
			
		||||
    search_fields = ["name", "is_superuser"]
 | 
			
		||||
    filterset_class = GroupFilter
 | 
			
		||||
    ordering = ["name"]
 | 
			
		||||
 | 
			
		||||
    def get_queryset(self):
 | 
			
		||||
        base_qs = Group.objects.all().select_related("parent").prefetch_related("roles")
 | 
			
		||||
        if self.serializer_class(context={"request": self.request})._should_include_users:
 | 
			
		||||
            base_qs = base_qs.prefetch_related("users")
 | 
			
		||||
        return base_qs
 | 
			
		||||
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        parameters=[
 | 
			
		||||
            OpenApiParameter("include_users", bool, default=True),
 | 
			
		||||
 | 
			
		||||
@ -2,6 +2,7 @@
 | 
			
		||||
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from django_filters.rest_framework import DjangoFilterBackend
 | 
			
		||||
from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer
 | 
			
		||||
from guardian.shortcuts import assign_perm, get_anonymous_user
 | 
			
		||||
@ -27,7 +28,6 @@ from authentik.core.models import (
 | 
			
		||||
    TokenIntents,
 | 
			
		||||
    User,
 | 
			
		||||
    default_token_duration,
 | 
			
		||||
    token_expires_from_timedelta,
 | 
			
		||||
)
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.events.utils import model_to_dict
 | 
			
		||||
@ -45,6 +45,13 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
 | 
			
		||||
        if SERIALIZER_CONTEXT_BLUEPRINT in self.context:
 | 
			
		||||
            self.fields["key"] = CharField(required=False)
 | 
			
		||||
 | 
			
		||||
    def validate_user(self, user: User):
 | 
			
		||||
        """Ensure user of token cannot be changed"""
 | 
			
		||||
        if self.instance and self.instance.user_id:
 | 
			
		||||
            if user.pk != self.instance.user_id:
 | 
			
		||||
                raise ValidationError("User cannot be changed")
 | 
			
		||||
        return user
 | 
			
		||||
 | 
			
		||||
    def validate(self, attrs: dict[Any, str]) -> dict[Any, str]:
 | 
			
		||||
        """Ensure only API or App password tokens are created."""
 | 
			
		||||
        request: Request = self.context.get("request")
 | 
			
		||||
@ -68,15 +75,17 @@ class TokenSerializer(ManagedSerializer, ModelSerializer):
 | 
			
		||||
            max_token_lifetime_dt = default_token_duration()
 | 
			
		||||
            if max_token_lifetime is not None:
 | 
			
		||||
                try:
 | 
			
		||||
                    max_token_lifetime_dt = timedelta_from_string(max_token_lifetime)
 | 
			
		||||
                    max_token_lifetime_dt = now() + timedelta_from_string(max_token_lifetime)
 | 
			
		||||
                except ValueError:
 | 
			
		||||
                    max_token_lifetime_dt = default_token_duration()
 | 
			
		||||
                    pass
 | 
			
		||||
 | 
			
		||||
            if "expires" in attrs and attrs.get("expires") > token_expires_from_timedelta(
 | 
			
		||||
                max_token_lifetime_dt
 | 
			
		||||
            ):
 | 
			
		||||
            if "expires" in attrs and attrs.get("expires") > max_token_lifetime_dt:
 | 
			
		||||
                raise ValidationError(
 | 
			
		||||
                    {"expires": f"Token expires exceeds maximum lifetime ({max_token_lifetime})."}
 | 
			
		||||
                    {
 | 
			
		||||
                        "expires": (
 | 
			
		||||
                            f"Token expires exceeds maximum lifetime ({max_token_lifetime_dt} UTC)."
 | 
			
		||||
                        )
 | 
			
		||||
                    }
 | 
			
		||||
                )
 | 
			
		||||
        elif attrs.get("intent") == TokenIntents.INTENT_API:
 | 
			
		||||
            # For API tokens, expires cannot be overridden
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
 | 
			
		||||
from authentik.core.api.utils import PassiveSerializer
 | 
			
		||||
from authentik.rbac.filters import ObjectFilter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeleteAction(Enum):
 | 
			
		||||
@ -53,7 +54,7 @@ class UsedByMixin:
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        responses={200: UsedBySerializer(many=True)},
 | 
			
		||||
    )
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[])
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
 | 
			
		||||
    def used_by(self, request: Request, *args, **kwargs) -> Response:
 | 
			
		||||
        """Get a list of all objects that use this object"""
 | 
			
		||||
        model: Model = self.get_object()
 | 
			
		||||
 | 
			
		||||
@ -407,8 +407,11 @@ class UserViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
    search_fields = ["username", "name", "is_active", "email", "uuid"]
 | 
			
		||||
    filterset_class = UsersFilter
 | 
			
		||||
 | 
			
		||||
    def get_queryset(self):  # pragma: no cover
 | 
			
		||||
        return User.objects.all().exclude_anonymous().prefetch_related("ak_groups")
 | 
			
		||||
    def get_queryset(self):
 | 
			
		||||
        base_qs = User.objects.all().exclude_anonymous()
 | 
			
		||||
        if self.serializer_class(context={"request": self.request})._should_include_groups:
 | 
			
		||||
            base_qs = base_qs.prefetch_related("ak_groups")
 | 
			
		||||
        return base_qs
 | 
			
		||||
 | 
			
		||||
    @extend_schema(
 | 
			
		||||
        parameters=[
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
"""authentik core models"""
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from hashlib import sha256
 | 
			
		||||
from typing import Any, Optional, Self
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
@ -68,11 +68,6 @@ def default_token_duration() -> datetime:
 | 
			
		||||
    return now() + timedelta_from_string(token_duration)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def token_expires_from_timedelta(dt: timedelta) -> datetime:
 | 
			
		||||
    """Return a `datetime.datetime` object with the duration of the Token"""
 | 
			
		||||
    return now() + dt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def default_token_key() -> str:
 | 
			
		||||
    """Default token key"""
 | 
			
		||||
    current_tenant = get_current_tenant()
 | 
			
		||||
@ -637,7 +632,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"User-source connection (user={self.user.username}, source={self.source.slug})"
 | 
			
		||||
        return f"User-source connection (user={self.user_id}, source={self.source_id})"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = (("user", "source"),)
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ from django.utils.translation import gettext as _
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection
 | 
			
		||||
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage
 | 
			
		||||
from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.flows.exceptions import FlowNonApplicableException
 | 
			
		||||
from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage
 | 
			
		||||
@ -100,8 +100,6 @@ class SourceFlowManager:
 | 
			
		||||
        if self.request.user.is_authenticated:
 | 
			
		||||
            new_connection.user = self.request.user
 | 
			
		||||
            new_connection = self.update_connection(new_connection, **kwargs)
 | 
			
		||||
 | 
			
		||||
            new_connection.save()
 | 
			
		||||
            return Action.LINK, new_connection
 | 
			
		||||
 | 
			
		||||
        existing_connections = self.connection_type.objects.filter(
 | 
			
		||||
@ -148,7 +146,6 @@ class SourceFlowManager:
 | 
			
		||||
        ]:
 | 
			
		||||
            new_connection.user = user
 | 
			
		||||
            new_connection = self.update_connection(new_connection, **kwargs)
 | 
			
		||||
            new_connection.save()
 | 
			
		||||
            return Action.LINK, new_connection
 | 
			
		||||
        if self.source.user_matching_mode in [
 | 
			
		||||
            SourceUserMatchingModes.EMAIL_DENY,
 | 
			
		||||
@ -209,13 +206,9 @@ class SourceFlowManager:
 | 
			
		||||
 | 
			
		||||
    def get_stages_to_append(self, flow: Flow) -> list[Stage]:
 | 
			
		||||
        """Hook to override stages which are appended to the flow"""
 | 
			
		||||
        if not self.source.enrollment_flow:
 | 
			
		||||
            return []
 | 
			
		||||
        if flow.slug == self.source.enrollment_flow.slug:
 | 
			
		||||
            return [
 | 
			
		||||
                in_memory_stage(PostUserEnrollmentStage),
 | 
			
		||||
            ]
 | 
			
		||||
        return []
 | 
			
		||||
        return [
 | 
			
		||||
            in_memory_stage(PostSourceStage),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def _prepare_flow(
 | 
			
		||||
        self,
 | 
			
		||||
@ -269,6 +262,9 @@ class SourceFlowManager:
 | 
			
		||||
            )
 | 
			
		||||
        # We run the Flow planner here so we can pass the Pending user in the context
 | 
			
		||||
        planner = FlowPlanner(flow)
 | 
			
		||||
        # We append some stages so the initial flow we get might be empty
 | 
			
		||||
        planner.allow_empty_flows = True
 | 
			
		||||
        planner.use_cache = False
 | 
			
		||||
        plan = planner.plan(self.request, kwargs)
 | 
			
		||||
        for stage in self.get_stages_to_append(flow):
 | 
			
		||||
            plan.append_stage(stage)
 | 
			
		||||
@ -327,7 +323,7 @@ class SourceFlowManager:
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_core:if-user",
 | 
			
		||||
            )
 | 
			
		||||
            + f"#/settings;page-{self.source.slug}"
 | 
			
		||||
            + "#/settings;page-sources"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def handle_enroll(
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,7 @@ from authentik.flows.stage import StageView
 | 
			
		||||
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PostUserEnrollmentStage(StageView):
 | 
			
		||||
class PostSourceStage(StageView):
 | 
			
		||||
    """Dynamically injected stage which saves the Connection after
 | 
			
		||||
    the user has been enrolled."""
 | 
			
		||||
 | 
			
		||||
@ -21,10 +21,12 @@ class PostUserEnrollmentStage(StageView):
 | 
			
		||||
        ]
 | 
			
		||||
        user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
 | 
			
		||||
        connection.user = user
 | 
			
		||||
        linked = connection.pk is None
 | 
			
		||||
        connection.save()
 | 
			
		||||
        Event.new(
 | 
			
		||||
            EventAction.SOURCE_LINKED,
 | 
			
		||||
            message="Linked Source",
 | 
			
		||||
            source=connection.source,
 | 
			
		||||
        ).from_http(self.request)
 | 
			
		||||
        if linked:
 | 
			
		||||
            Event.new(
 | 
			
		||||
                EventAction.SOURCE_LINKED,
 | 
			
		||||
                message="Linked Source",
 | 
			
		||||
                source=connection.source,
 | 
			
		||||
            ).from_http(self.request)
 | 
			
		||||
        return self.executor.stage_ok()
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,9 @@
 | 
			
		||||
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
 | 
			
		||||
from django.conf import ImproperlyConfigured
 | 
			
		||||
from django.contrib.sessions.backends.cache import KEY_PREFIX
 | 
			
		||||
from django.contrib.sessions.backends.db import SessionStore as DBSessionStore
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.utils.timezone import now
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -15,6 +17,7 @@ from authentik.core.models import (
 | 
			
		||||
    User,
 | 
			
		||||
)
 | 
			
		||||
from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.root.celery import CELERY_APP
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
@ -39,16 +42,31 @@ def clean_expired_models(self: SystemTask):
 | 
			
		||||
    amount = 0
 | 
			
		||||
 | 
			
		||||
    for session in AuthenticatedSession.objects.all():
 | 
			
		||||
        cache_key = f"{KEY_PREFIX}{session.session_key}"
 | 
			
		||||
        value = None
 | 
			
		||||
        try:
 | 
			
		||||
            value = cache.get(cache_key)
 | 
			
		||||
        match CONFIG.get("session_storage", "cache"):
 | 
			
		||||
            case "cache":
 | 
			
		||||
                cache_key = f"{KEY_PREFIX}{session.session_key}"
 | 
			
		||||
                value = None
 | 
			
		||||
                try:
 | 
			
		||||
                    value = cache.get(cache_key)
 | 
			
		||||
 | 
			
		||||
        except Exception as exc:
 | 
			
		||||
            LOGGER.debug("Failed to get session from cache", exc=exc)
 | 
			
		||||
        if not value:
 | 
			
		||||
            session.delete()
 | 
			
		||||
            amount += 1
 | 
			
		||||
                except Exception as exc:
 | 
			
		||||
                    LOGGER.debug("Failed to get session from cache", exc=exc)
 | 
			
		||||
                if not value:
 | 
			
		||||
                    session.delete()
 | 
			
		||||
                    amount += 1
 | 
			
		||||
            case "db":
 | 
			
		||||
                if not (
 | 
			
		||||
                    DBSessionStore.get_model_class()
 | 
			
		||||
                    .objects.filter(session_key=session.session_key, expire_date__gt=now())
 | 
			
		||||
                    .exists()
 | 
			
		||||
                ):
 | 
			
		||||
                    session.delete()
 | 
			
		||||
                    amount += 1
 | 
			
		||||
            case _:
 | 
			
		||||
                # Should never happen, as we check for other values in authentik/root/settings.py
 | 
			
		||||
                raise ImproperlyConfigured(
 | 
			
		||||
                    "Invalid session_storage setting, allowed values are db and cache"
 | 
			
		||||
                )
 | 
			
		||||
    LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount)
 | 
			
		||||
 | 
			
		||||
    messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}")
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ from guardian.shortcuts import assign_perm
 | 
			
		||||
from rest_framework.test import APITestCase
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.core.tests.utils import create_test_user
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,13 @@ class TestGroupsAPI(APITestCase):
 | 
			
		||||
        self.login_user = create_test_user()
 | 
			
		||||
        self.user = User.objects.create(username="test-user")
 | 
			
		||||
 | 
			
		||||
    def test_list_with_users(self):
 | 
			
		||||
        """Test listing with users"""
 | 
			
		||||
        admin = create_test_admin_user()
 | 
			
		||||
        self.client.force_login(admin)
 | 
			
		||||
        response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
    def test_add_user(self):
 | 
			
		||||
        """Test add_user"""
 | 
			
		||||
        group = Group.objects.create(name=generate_id())
 | 
			
		||||
 | 
			
		||||
@ -2,11 +2,15 @@
 | 
			
		||||
 | 
			
		||||
from django.contrib.auth.models import AnonymousUser
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
from guardian.utils import get_anonymous_user
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import SourceUserMatchingModes, User
 | 
			
		||||
from authentik.core.sources.flow_manager import Action
 | 
			
		||||
from authentik.core.sources.stage import PostSourceStage
 | 
			
		||||
from authentik.core.tests.utils import create_test_flow
 | 
			
		||||
from authentik.flows.planner import FlowPlan
 | 
			
		||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.lib.tests.utils import get_request
 | 
			
		||||
from authentik.policies.denied import AccessDeniedResponse
 | 
			
		||||
@ -21,42 +25,62 @@ class TestSourceFlowManager(TestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        self.source: OAuthSource = OAuthSource.objects.create(name="test")
 | 
			
		||||
        self.authentication_flow = create_test_flow()
 | 
			
		||||
        self.enrollment_flow = create_test_flow()
 | 
			
		||||
        self.source: OAuthSource = OAuthSource.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            slug=generate_id(),
 | 
			
		||||
            authentication_flow=self.authentication_flow,
 | 
			
		||||
            enrollment_flow=self.enrollment_flow,
 | 
			
		||||
        )
 | 
			
		||||
        self.identifier = generate_id()
 | 
			
		||||
 | 
			
		||||
    def test_unauthenticated_enroll(self):
 | 
			
		||||
        """Test un-authenticated user enrolling"""
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(
 | 
			
		||||
            self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
 | 
			
		||||
        )
 | 
			
		||||
        request = get_request("/", user=AnonymousUser())
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
 | 
			
		||||
        action, _ = flow_manager.get_action()
 | 
			
		||||
        self.assertEqual(action, Action.ENROLL)
 | 
			
		||||
        flow_manager.get_flow()
 | 
			
		||||
        response = flow_manager.get_flow()
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
 | 
			
		||||
        self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
 | 
			
		||||
 | 
			
		||||
    def test_unauthenticated_auth(self):
 | 
			
		||||
        """Test un-authenticated user authenticating"""
 | 
			
		||||
        UserOAuthSourceConnection.objects.create(
 | 
			
		||||
            user=get_anonymous_user(), source=self.source, identifier=self.identifier
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(
 | 
			
		||||
            self.source, get_request("/", user=AnonymousUser()), self.identifier, {}
 | 
			
		||||
        )
 | 
			
		||||
        request = get_request("/", user=AnonymousUser())
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
 | 
			
		||||
        action, _ = flow_manager.get_action()
 | 
			
		||||
        self.assertEqual(action, Action.AUTH)
 | 
			
		||||
        flow_manager.get_flow()
 | 
			
		||||
        response = flow_manager.get_flow()
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN]
 | 
			
		||||
        self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage)
 | 
			
		||||
 | 
			
		||||
    def test_authenticated_link(self):
 | 
			
		||||
        """Test authenticated user linking"""
 | 
			
		||||
        UserOAuthSourceConnection.objects.create(
 | 
			
		||||
            user=get_anonymous_user(), source=self.source, identifier=self.identifier
 | 
			
		||||
        )
 | 
			
		||||
        user = User.objects.create(username="foo", email="foo@bar.baz")
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(
 | 
			
		||||
            self.source, get_request("/", user=user), self.identifier, {}
 | 
			
		||||
        )
 | 
			
		||||
        action, _ = flow_manager.get_action()
 | 
			
		||||
        request = get_request("/", user=user)
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {})
 | 
			
		||||
        action, connection = flow_manager.get_action()
 | 
			
		||||
        self.assertEqual(action, Action.LINK)
 | 
			
		||||
        self.assertIsNone(connection.pk)
 | 
			
		||||
        response = flow_manager.get_flow()
 | 
			
		||||
        self.assertEqual(response.status_code, 302)
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            response.url,
 | 
			
		||||
            reverse("authentik_core:if-user") + "#/settings;page-sources",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_unauthenticated_link(self):
 | 
			
		||||
        """Test un-authenticated user linking"""
 | 
			
		||||
        flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {})
 | 
			
		||||
        action, connection = flow_manager.get_action()
 | 
			
		||||
        self.assertEqual(action, Action.LINK)
 | 
			
		||||
        self.assertIsNone(connection.pk)
 | 
			
		||||
        flow_manager.get_flow()
 | 
			
		||||
 | 
			
		||||
    def test_unauthenticated_enroll_email(self):
 | 
			
		||||
 | 
			
		||||
@ -13,9 +13,8 @@ from authentik.core.models import (
 | 
			
		||||
    USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME,
 | 
			
		||||
    Token,
 | 
			
		||||
    TokenIntents,
 | 
			
		||||
    User,
 | 
			
		||||
)
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user, create_test_user
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,7 +23,7 @@ class TestTokenAPI(APITestCase):
 | 
			
		||||
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        self.user = User.objects.create(username="testuser")
 | 
			
		||||
        self.user = create_test_user()
 | 
			
		||||
        self.admin = create_test_admin_user()
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
 | 
			
		||||
@ -154,6 +153,24 @@ class TestTokenAPI(APITestCase):
 | 
			
		||||
        self.assertEqual(token.expiring, True)
 | 
			
		||||
        self.assertNotEqual(token.expires.timestamp(), expires.timestamp())
 | 
			
		||||
 | 
			
		||||
    def test_token_change_user(self):
 | 
			
		||||
        """Test creating a token and then changing the user"""
 | 
			
		||||
        ident = generate_id()
 | 
			
		||||
        response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident})
 | 
			
		||||
        self.assertEqual(response.status_code, 201)
 | 
			
		||||
        token = Token.objects.get(identifier=ident)
 | 
			
		||||
        self.assertEqual(token.user, self.user)
 | 
			
		||||
        self.assertEqual(token.intent, TokenIntents.INTENT_API)
 | 
			
		||||
        self.assertEqual(token.expiring, True)
 | 
			
		||||
        self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token))
 | 
			
		||||
        response = self.client.put(
 | 
			
		||||
            reverse("authentik_api:token-detail", kwargs={"identifier": ident}),
 | 
			
		||||
            data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk},
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 400)
 | 
			
		||||
        token.refresh_from_db()
 | 
			
		||||
        self.assertEqual(token.user, self.user)
 | 
			
		||||
 | 
			
		||||
    def test_list(self):
 | 
			
		||||
        """Test Token List (Test normal authentication)"""
 | 
			
		||||
        Token.objects.all().delete()
 | 
			
		||||
 | 
			
		||||
@ -41,6 +41,12 @@ class TestUsersAPI(APITestCase):
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
    def test_list_with_groups(self):
 | 
			
		||||
        """Test listing with groups"""
 | 
			
		||||
        self.client.force_login(self.admin)
 | 
			
		||||
        response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"})
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
    def test_metrics(self):
 | 
			
		||||
        """Test user's metrics"""
 | 
			
		||||
        self.client.force_login(self.admin)
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@ from rest_framework.test import APITestCase
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.tenants.utils import get_current_tenant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -25,7 +24,6 @@ class TestUsersAvatars(APITestCase):
 | 
			
		||||
        tenant.avatars = mode
 | 
			
		||||
        tenant.save()
 | 
			
		||||
 | 
			
		||||
    @CONFIG.patch("avatars", "none")
 | 
			
		||||
    def test_avatars_none(self):
 | 
			
		||||
        """Test avatars none"""
 | 
			
		||||
        self.set_avatar_mode("none")
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ from django.utils.text import slugify
 | 
			
		||||
 | 
			
		||||
from authentik.brands.models import Brand
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.crypto.builder import CertificateBuilder
 | 
			
		||||
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.flows.models import Flow, FlowDesignation
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
@ -50,12 +50,10 @@ def create_test_brand(**kwargs) -> Brand:
 | 
			
		||||
    return Brand.objects.create(domain=uid, default=True, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair:
 | 
			
		||||
def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair:
 | 
			
		||||
    """Generate a certificate for testing"""
 | 
			
		||||
    builder = CertificateBuilder(
 | 
			
		||||
        name=f"{generate_id()}.self-signed.goauthentik.io",
 | 
			
		||||
        use_ec_private_key=use_ec_private_key,
 | 
			
		||||
    )
 | 
			
		||||
    builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io")
 | 
			
		||||
    builder.alg = alg
 | 
			
		||||
    builder.build(
 | 
			
		||||
        subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"],
 | 
			
		||||
        validity_days=360,
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,13 @@ from drf_spectacular.types import OpenApiTypes
 | 
			
		||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
 | 
			
		||||
from rest_framework.decorators import action
 | 
			
		||||
from rest_framework.exceptions import ValidationError
 | 
			
		||||
from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField
 | 
			
		||||
from rest_framework.fields import (
 | 
			
		||||
    CharField,
 | 
			
		||||
    ChoiceField,
 | 
			
		||||
    DateTimeField,
 | 
			
		||||
    IntegerField,
 | 
			
		||||
    SerializerMethodField,
 | 
			
		||||
)
 | 
			
		||||
from rest_framework.filters import OrderingFilter, SearchFilter
 | 
			
		||||
from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
@ -26,10 +32,11 @@ from authentik.api.authorization import SecretKeyFilter
 | 
			
		||||
from authentik.core.api.used_by import UsedByMixin
 | 
			
		||||
from authentik.core.api.utils import PassiveSerializer
 | 
			
		||||
from authentik.crypto.apps import MANAGED_KEY
 | 
			
		||||
from authentik.crypto.builder import CertificateBuilder
 | 
			
		||||
from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.rbac.decorators import permission_required
 | 
			
		||||
from authentik.rbac.filters import ObjectFilter
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -178,6 +185,7 @@ class CertificateGenerationSerializer(PassiveSerializer):
 | 
			
		||||
    common_name = CharField()
 | 
			
		||||
    subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name"))
 | 
			
		||||
    validity_days = IntegerField(initial=365)
 | 
			
		||||
    alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CertificateKeyPairFilter(FilterSet):
 | 
			
		||||
@ -240,6 +248,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
        raw_san = data.validated_data.get("subject_alt_name", "")
 | 
			
		||||
        sans = raw_san.split(",") if raw_san != "" else []
 | 
			
		||||
        builder = CertificateBuilder(data.validated_data["common_name"])
 | 
			
		||||
        builder.alg = data.validated_data["alg"]
 | 
			
		||||
        builder.build(
 | 
			
		||||
            subject_alt_names=sans,
 | 
			
		||||
            validity_days=int(data.validated_data["validity_days"]),
 | 
			
		||||
@ -258,7 +267,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
        ],
 | 
			
		||||
        responses={200: CertificateDataSerializer(many=False)},
 | 
			
		||||
    )
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[])
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
 | 
			
		||||
    def view_certificate(self, request: Request, pk: str) -> Response:
 | 
			
		||||
        """Return certificate-key pairs certificate and log access"""
 | 
			
		||||
        certificate: CertificateKeyPair = self.get_object()
 | 
			
		||||
@ -288,7 +297,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
        ],
 | 
			
		||||
        responses={200: CertificateDataSerializer(many=False)},
 | 
			
		||||
    )
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[])
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
 | 
			
		||||
    def view_private_key(self, request: Request, pk: str) -> Response:
 | 
			
		||||
        """Return certificate-key pairs private key and log access"""
 | 
			
		||||
        certificate: CertificateKeyPair = self.get_object()
 | 
			
		||||
 | 
			
		||||
@ -9,20 +9,28 @@ from cryptography.hazmat.primitives import hashes, serialization
 | 
			
		||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa
 | 
			
		||||
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
 | 
			
		||||
from cryptography.x509.oid import NameOID
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.utils.translation import gettext_lazy as _
 | 
			
		||||
 | 
			
		||||
from authentik import __version__
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PrivateKeyAlg(models.TextChoices):
 | 
			
		||||
    """Algorithm to create private key with"""
 | 
			
		||||
 | 
			
		||||
    RSA = "rsa", _("rsa")
 | 
			
		||||
    ECDSA = "ecdsa", _("ecdsa")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CertificateBuilder:
 | 
			
		||||
    """Build self-signed certificates"""
 | 
			
		||||
 | 
			
		||||
    common_name: str
 | 
			
		||||
    alg: PrivateKeyAlg
 | 
			
		||||
 | 
			
		||||
    _use_ec_private_key: bool
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name: str, use_ec_private_key=False):
 | 
			
		||||
        self._use_ec_private_key = use_ec_private_key
 | 
			
		||||
    def __init__(self, name: str):
 | 
			
		||||
        self.alg = PrivateKeyAlg.RSA
 | 
			
		||||
        self.__public_key = None
 | 
			
		||||
        self.__private_key = None
 | 
			
		||||
        self.__builder = None
 | 
			
		||||
@ -42,11 +50,13 @@ class CertificateBuilder:
 | 
			
		||||
 | 
			
		||||
    def generate_private_key(self) -> PrivateKeyTypes:
 | 
			
		||||
        """Generate private key"""
 | 
			
		||||
        if self._use_ec_private_key:
 | 
			
		||||
        if self.alg == PrivateKeyAlg.ECDSA:
 | 
			
		||||
            return ec.generate_private_key(curve=ec.SECP256R1())
 | 
			
		||||
        return rsa.generate_private_key(
 | 
			
		||||
            public_exponent=65537, key_size=4096, backend=default_backend()
 | 
			
		||||
        )
 | 
			
		||||
        if self.alg == PrivateKeyAlg.RSA:
 | 
			
		||||
            return rsa.generate_private_key(
 | 
			
		||||
                public_exponent=65537, key_size=4096, backend=default_backend()
 | 
			
		||||
            )
 | 
			
		||||
        raise ValueError(f"Invalid alg: {self.alg}")
 | 
			
		||||
 | 
			
		||||
    def build(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -214,6 +214,46 @@ class TestCrypto(APITestCase):
 | 
			
		||||
        self.assertEqual(200, response.status_code)
 | 
			
		||||
        self.assertIn("Content-Disposition", response)
 | 
			
		||||
 | 
			
		||||
    def test_certificate_download_denied(self):
 | 
			
		||||
        """Test certificate export (download)"""
 | 
			
		||||
        self.client.logout()
 | 
			
		||||
        keypair = create_test_cert()
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:certificatekeypair-view-certificate",
 | 
			
		||||
                kwargs={"pk": keypair.pk},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(403, response.status_code)
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:certificatekeypair-view-certificate",
 | 
			
		||||
                kwargs={"pk": keypair.pk},
 | 
			
		||||
            ),
 | 
			
		||||
            data={"download": True},
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(403, response.status_code)
 | 
			
		||||
 | 
			
		||||
    def test_private_key_download_denied(self):
 | 
			
		||||
        """Test private_key export (download)"""
 | 
			
		||||
        self.client.logout()
 | 
			
		||||
        keypair = create_test_cert()
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:certificatekeypair-view-private-key",
 | 
			
		||||
                kwargs={"pk": keypair.pk},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(403, response.status_code)
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:certificatekeypair-view-private-key",
 | 
			
		||||
                kwargs={"pk": keypair.pk},
 | 
			
		||||
            ),
 | 
			
		||||
            data={"download": True},
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(403, response.status_code)
 | 
			
		||||
 | 
			
		||||
    def test_used_by(self):
 | 
			
		||||
        """Test used_by endpoint"""
 | 
			
		||||
        self.client.force_login(create_test_admin_user())
 | 
			
		||||
@ -246,6 +286,26 @@ class TestCrypto(APITestCase):
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_used_by_denied(self):
 | 
			
		||||
        """Test used_by endpoint"""
 | 
			
		||||
        self.client.logout()
 | 
			
		||||
        keypair = create_test_cert()
 | 
			
		||||
        OAuth2Provider.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            client_id="test",
 | 
			
		||||
            client_secret=generate_key(),
 | 
			
		||||
            authorization_flow=create_test_flow(),
 | 
			
		||||
            redirect_uris="http://localhost",
 | 
			
		||||
            signing_key=keypair,
 | 
			
		||||
        )
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse(
 | 
			
		||||
                "authentik_api:certificatekeypair-used-by",
 | 
			
		||||
                kwargs={"pk": keypair.pk},
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(403, response.status_code)
 | 
			
		||||
 | 
			
		||||
    def test_discovery(self):
 | 
			
		||||
        """Test certificate discovery"""
 | 
			
		||||
        name = generate_id()
 | 
			
		||||
 | 
			
		||||
@ -2,11 +2,12 @@
 | 
			
		||||
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from functools import partial
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.apps.registry import apps
 | 
			
		||||
from django.core.files import File
 | 
			
		||||
from django.db import connection
 | 
			
		||||
from django.db.models import Model
 | 
			
		||||
from django.db.models import ManyToManyRel, Model
 | 
			
		||||
from django.db.models.expressions import BaseExpression, Combinable
 | 
			
		||||
from django.db.models.signals import post_init
 | 
			
		||||
from django.http import HttpRequest
 | 
			
		||||
@ -44,7 +45,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
 | 
			
		||||
        post_init.disconnect(dispatch_uid=request.request_id)
 | 
			
		||||
 | 
			
		||||
    def serialize_simple(self, model: Model) -> dict:
 | 
			
		||||
        """Serialize a model in a very simple way. No ForeginKeys or other relationships are
 | 
			
		||||
        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
 | 
			
		||||
        resolved"""
 | 
			
		||||
        data = {}
 | 
			
		||||
        deferred_fields = model.get_deferred_fields()
 | 
			
		||||
@ -70,6 +71,9 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
 | 
			
		||||
        for key, value in before.items():
 | 
			
		||||
            if after.get(key) != value:
 | 
			
		||||
                diff[key] = {"previous_value": value, "new_value": after.get(key)}
 | 
			
		||||
        for key, value in after.items():
 | 
			
		||||
            if key not in before and key not in diff and before.get(key) != value:
 | 
			
		||||
                diff[key] = {"previous_value": before.get(key), "new_value": value}
 | 
			
		||||
        return sanitize_item(diff)
 | 
			
		||||
 | 
			
		||||
    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
 | 
			
		||||
@ -98,8 +102,37 @@ class EnterpriseAuditMiddleware(AuditMiddleware):
 | 
			
		||||
        thread_kwargs = {}
 | 
			
		||||
        if hasattr(instance, "_previous_state") or created:
 | 
			
		||||
            prev_state = getattr(instance, "_previous_state", {})
 | 
			
		||||
            if created:
 | 
			
		||||
                prev_state = {}
 | 
			
		||||
            # Get current state
 | 
			
		||||
            new_state = self.serialize_simple(instance)
 | 
			
		||||
            diff = self.diff(prev_state, new_state)
 | 
			
		||||
            thread_kwargs["diff"] = diff
 | 
			
		||||
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
 | 
			
		||||
 | 
			
		||||
    def m2m_changed_handler(  # noqa: PLR0913
 | 
			
		||||
        self,
 | 
			
		||||
        request: HttpRequest,
 | 
			
		||||
        sender,
 | 
			
		||||
        instance: Model,
 | 
			
		||||
        action: str,
 | 
			
		||||
        pk_set: set[Any],
 | 
			
		||||
        thread_kwargs: dict | None = None,
 | 
			
		||||
        **_,
 | 
			
		||||
    ):
 | 
			
		||||
        thread_kwargs = {}
 | 
			
		||||
        m2m_field = None
 | 
			
		||||
        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
 | 
			
		||||
        _, _, action_direction = action.partition("_")
 | 
			
		||||
        # resolve the "through" model to an actual field
 | 
			
		||||
        for field in instance._meta.get_fields():
 | 
			
		||||
            if not isinstance(field, ManyToManyRel):
 | 
			
		||||
                continue
 | 
			
		||||
            if field.through == sender:
 | 
			
		||||
                m2m_field = field
 | 
			
		||||
        if m2m_field:
 | 
			
		||||
            # If we're clearing we just set the "flag" to True
 | 
			
		||||
            if action_direction == "clear":
 | 
			
		||||
                pk_set = True
 | 
			
		||||
            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
 | 
			
		||||
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,22 @@
 | 
			
		||||
from unittest.mock import PropertyMock, patch
 | 
			
		||||
 | 
			
		||||
from django.apps import apps
 | 
			
		||||
from django.conf import settings
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
from rest_framework.test import APITestCase
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user
 | 
			
		||||
from authentik.events.models import Event, EventAction
 | 
			
		||||
from authentik.events.utils import sanitize_item
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestEnterpriseAudit(TestCase):
 | 
			
		||||
class TestEnterpriseAudit(APITestCase):
 | 
			
		||||
    """Test audit middleware"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        self.user = create_test_admin_user()
 | 
			
		||||
 | 
			
		||||
    def test_import(self):
 | 
			
		||||
        """Ensure middleware is imported when app.ready is called"""
 | 
			
		||||
@ -16,3 +29,182 @@ class TestEnterpriseAudit(TestCase):
 | 
			
		||||
        self.assertIn(
 | 
			
		||||
            "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
        "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
			
		||||
        PropertyMock(return_value=True),
 | 
			
		||||
    )
 | 
			
		||||
    def test_create(self):
 | 
			
		||||
        """Test create audit log"""
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
        username = generate_id()
 | 
			
		||||
        response = self.client.post(
 | 
			
		||||
            reverse("authentik_api:user-list"),
 | 
			
		||||
            data={"name": generate_id(), "username": username, "groups": [], "path": "foo"},
 | 
			
		||||
        )
 | 
			
		||||
        user = User.objects.get(username=username)
 | 
			
		||||
        self.assertEqual(response.status_code, 201)
 | 
			
		||||
        events = Event.objects.filter(
 | 
			
		||||
            action=EventAction.MODEL_CREATED,
 | 
			
		||||
            context__model__model_name="user",
 | 
			
		||||
            context__model__app="authentik_core",
 | 
			
		||||
            context__model__pk=user.pk,
 | 
			
		||||
        )
 | 
			
		||||
        event = events.first()
 | 
			
		||||
        self.assertIsNotNone(event)
 | 
			
		||||
        self.assertIsNotNone(event.context["diff"])
 | 
			
		||||
        diff = event.context["diff"]
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            diff,
 | 
			
		||||
            {
 | 
			
		||||
                "name": {
 | 
			
		||||
                    "new_value": user.name,
 | 
			
		||||
                    "previous_value": None,
 | 
			
		||||
                },
 | 
			
		||||
                "path": {"new_value": "foo", "previous_value": None},
 | 
			
		||||
                "type": {"new_value": "internal", "previous_value": None},
 | 
			
		||||
                "uuid": {
 | 
			
		||||
                    "new_value": user.uuid.hex,
 | 
			
		||||
                    "previous_value": None,
 | 
			
		||||
                },
 | 
			
		||||
                "email": {"new_value": "", "previous_value": None},
 | 
			
		||||
                "username": {
 | 
			
		||||
                    "new_value": user.username,
 | 
			
		||||
                    "previous_value": None,
 | 
			
		||||
                },
 | 
			
		||||
                "is_active": {"new_value": True, "previous_value": None},
 | 
			
		||||
                "attributes": {"new_value": {}, "previous_value": None},
 | 
			
		||||
                "date_joined": {
 | 
			
		||||
                    "new_value": sanitize_item(user.date_joined),
 | 
			
		||||
                    "previous_value": None,
 | 
			
		||||
                },
 | 
			
		||||
                "first_name": {"new_value": "", "previous_value": None},
 | 
			
		||||
                "id": {"new_value": user.pk, "previous_value": None},
 | 
			
		||||
                "last_name": {"new_value": "", "previous_value": None},
 | 
			
		||||
                "password": {"new_value": "********************", "previous_value": None},
 | 
			
		||||
                "password_change_date": {
 | 
			
		||||
                    "new_value": sanitize_item(user.password_change_date),
 | 
			
		||||
                    "previous_value": None,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
        "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
			
		||||
        PropertyMock(return_value=True),
 | 
			
		||||
    )
 | 
			
		||||
    def test_update(self):
 | 
			
		||||
        """Test update audit log"""
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
        user = create_test_admin_user()
 | 
			
		||||
        current_name = user.name
 | 
			
		||||
        new_name = generate_id()
 | 
			
		||||
        response = self.client.patch(
 | 
			
		||||
            reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
 | 
			
		||||
            data={"name": new_name},
 | 
			
		||||
        )
 | 
			
		||||
        user.refresh_from_db()
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
        events = Event.objects.filter(
 | 
			
		||||
            action=EventAction.MODEL_UPDATED,
 | 
			
		||||
            context__model__model_name="user",
 | 
			
		||||
            context__model__app="authentik_core",
 | 
			
		||||
            context__model__pk=user.pk,
 | 
			
		||||
        )
 | 
			
		||||
        event = events.first()
 | 
			
		||||
        self.assertIsNotNone(event)
 | 
			
		||||
        self.assertIsNotNone(event.context["diff"])
 | 
			
		||||
        diff = event.context["diff"]
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            diff,
 | 
			
		||||
            {
 | 
			
		||||
                "name": {
 | 
			
		||||
                    "new_value": new_name,
 | 
			
		||||
                    "previous_value": current_name,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
        "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
			
		||||
        PropertyMock(return_value=True),
 | 
			
		||||
    )
 | 
			
		||||
    def test_delete(self):
 | 
			
		||||
        """Test delete audit log"""
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
        user = create_test_admin_user()
 | 
			
		||||
        response = self.client.delete(
 | 
			
		||||
            reverse("authentik_api:user-detail", kwargs={"pk": user.id}),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 204)
 | 
			
		||||
        events = Event.objects.filter(
 | 
			
		||||
            action=EventAction.MODEL_DELETED,
 | 
			
		||||
            context__model__model_name="user",
 | 
			
		||||
            context__model__app="authentik_core",
 | 
			
		||||
            context__model__pk=user.pk,
 | 
			
		||||
        )
 | 
			
		||||
        event = events.first()
 | 
			
		||||
        self.assertIsNotNone(event)
 | 
			
		||||
        self.assertNotIn("diff", event.context)
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
        "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
			
		||||
        PropertyMock(return_value=True),
 | 
			
		||||
    )
 | 
			
		||||
    def test_m2m_add(self):
 | 
			
		||||
        """Test m2m add audit log"""
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
        user = create_test_admin_user()
 | 
			
		||||
        group = Group.objects.create(name=generate_id())
 | 
			
		||||
        response = self.client.post(
 | 
			
		||||
            reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}),
 | 
			
		||||
            data={
 | 
			
		||||
                "pk": user.pk,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 204)
 | 
			
		||||
        events = Event.objects.filter(
 | 
			
		||||
            action=EventAction.MODEL_UPDATED,
 | 
			
		||||
            context__model__model_name="group",
 | 
			
		||||
            context__model__app="authentik_core",
 | 
			
		||||
            context__model__pk=group.pk.hex,
 | 
			
		||||
        )
 | 
			
		||||
        event = events.first()
 | 
			
		||||
        self.assertIsNotNone(event)
 | 
			
		||||
        self.assertIsNotNone(event.context["diff"])
 | 
			
		||||
        diff = event.context["diff"]
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            diff,
 | 
			
		||||
            {"users": {"add": [user.pk]}},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @patch(
 | 
			
		||||
        "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled",
 | 
			
		||||
        PropertyMock(return_value=True),
 | 
			
		||||
    )
 | 
			
		||||
    def test_m2m_remove(self):
 | 
			
		||||
        """Test m2m remove audit log"""
 | 
			
		||||
        self.client.force_login(self.user)
 | 
			
		||||
        user = create_test_admin_user()
 | 
			
		||||
        group = Group.objects.create(name=generate_id())
 | 
			
		||||
        response = self.client.post(
 | 
			
		||||
            reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}),
 | 
			
		||||
            data={
 | 
			
		||||
                "pk": user.pk,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 204)
 | 
			
		||||
        events = Event.objects.filter(
 | 
			
		||||
            action=EventAction.MODEL_UPDATED,
 | 
			
		||||
            context__model__model_name="group",
 | 
			
		||||
            context__model__app="authentik_core",
 | 
			
		||||
            context__model__pk=group.pk.hex,
 | 
			
		||||
        )
 | 
			
		||||
        event = events.first()
 | 
			
		||||
        self.assertIsNotNone(event)
 | 
			
		||||
        self.assertIsNotNone(event.context["diff"])
 | 
			
		||||
        diff = event.context["diff"]
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            diff,
 | 
			
		||||
            {"users": {"remove": [user.pk]}},
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -201,10 +201,7 @@ class ConnectionToken(ExpiringModel):
 | 
			
		||||
        return settings
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return (
 | 
			
		||||
            f"RAC Connection token {self.session.user} to "
 | 
			
		||||
            f"{self.endpoint.provider.name}/{self.endpoint.name}"
 | 
			
		||||
        )
 | 
			
		||||
        return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("RAC Connection token")
 | 
			
		||||
 | 
			
		||||
@ -116,12 +116,12 @@ class AuditMiddleware:
 | 
			
		||||
            return user
 | 
			
		||||
        user = getattr(request, "user", self.anonymous_user)
 | 
			
		||||
        if not user.is_authenticated:
 | 
			
		||||
            self._ensure_fallback_user()
 | 
			
		||||
            return self.anonymous_user
 | 
			
		||||
        return user
 | 
			
		||||
 | 
			
		||||
    def connect(self, request: HttpRequest):
 | 
			
		||||
        """Connect signal for automatic logging"""
 | 
			
		||||
        self._ensure_fallback_user()
 | 
			
		||||
        if not hasattr(request, "request_id"):
 | 
			
		||||
            return
 | 
			
		||||
        post_save.connect(
 | 
			
		||||
@ -214,7 +214,15 @@ class AuditMiddleware:
 | 
			
		||||
            model=model_to_dict(instance),
 | 
			
		||||
        ).run()
 | 
			
		||||
 | 
			
		||||
    def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_):
 | 
			
		||||
    def m2m_changed_handler(
 | 
			
		||||
        self,
 | 
			
		||||
        request: HttpRequest,
 | 
			
		||||
        sender,
 | 
			
		||||
        instance: Model,
 | 
			
		||||
        action: str,
 | 
			
		||||
        thread_kwargs: dict | None = None,
 | 
			
		||||
        **_,
 | 
			
		||||
    ):
 | 
			
		||||
        """Signal handler for all object's m2m_changed"""
 | 
			
		||||
        if action not in ["pre_add", "pre_remove", "post_clear"]:
 | 
			
		||||
            return
 | 
			
		||||
@ -229,4 +237,5 @@ class AuditMiddleware:
 | 
			
		||||
            request,
 | 
			
		||||
            user=user,
 | 
			
		||||
            model=model_to_dict(instance),
 | 
			
		||||
            **thread_kwargs,
 | 
			
		||||
        ).run()
 | 
			
		||||
 | 
			
		||||
@ -556,7 +556,7 @@ class Notification(SerializerModel):
 | 
			
		||||
            if len(self.body) > NOTIFICATION_SUMMARY_LENGTH
 | 
			
		||||
            else self.body
 | 
			
		||||
        )
 | 
			
		||||
        return f"Notification for user {self.user}: {body_trunc}"
 | 
			
		||||
        return f"Notification for user {self.user_id}: {body_trunc}"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("Notification")
 | 
			
		||||
 | 
			
		||||
@ -119,7 +119,7 @@ class SystemTask(TenantTask):
 | 
			
		||||
                "task_call_kwargs": sanitize_item(kwargs),
 | 
			
		||||
                "status": self._status,
 | 
			
		||||
                "messages": sanitize_item(self._messages),
 | 
			
		||||
                "expires": now() + timedelta(hours=self.result_timeout_hours),
 | 
			
		||||
                "expires": now() + timedelta(hours=self.result_timeout_hours + 3),
 | 
			
		||||
                "expiring": True,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										35
									
								
								authentik/events/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								authentik/events/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,35 @@
 | 
			
		||||
"""authentik event models tests"""
 | 
			
		||||
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
 | 
			
		||||
from django.db.models import Model
 | 
			
		||||
from django.test import TestCase
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import default_token_key
 | 
			
		||||
from authentik.lib.utils.reflection import get_apps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestModels(TestCase):
 | 
			
		||||
    """Test Models"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def model_tester_factory(test_model: type[Model]) -> Callable:
 | 
			
		||||
    """Test models' __str__ and __repr__"""
 | 
			
		||||
 | 
			
		||||
    def tester(self: TestModels):
 | 
			
		||||
        allowed = 0
 | 
			
		||||
        # Token-like objects need to lookup the current tenant to get the default token length
 | 
			
		||||
        for field in test_model._meta.fields:
 | 
			
		||||
            if field.default == default_token_key:
 | 
			
		||||
                allowed += 1
 | 
			
		||||
        with self.assertNumQueries(allowed):
 | 
			
		||||
            str(test_model())
 | 
			
		||||
        with self.assertNumQueries(allowed):
 | 
			
		||||
            repr(test_model())
 | 
			
		||||
 | 
			
		||||
    return tester
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
for app in get_apps():
 | 
			
		||||
    for model in app.get_models():
 | 
			
		||||
        setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model))
 | 
			
		||||
@ -33,6 +33,7 @@ from authentik.lib.utils.file import (
 | 
			
		||||
)
 | 
			
		||||
from authentik.lib.views import bad_request_message
 | 
			
		||||
from authentik.rbac.decorators import permission_required
 | 
			
		||||
from authentik.rbac.filters import ObjectFilter
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -277,8 +278,8 @@ class FlowViewSet(UsedByMixin, ModelViewSet):
 | 
			
		||||
            400: OpenApiResponse(description="Flow not applicable"),
 | 
			
		||||
        },
 | 
			
		||||
    )
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[])
 | 
			
		||||
    def execute(self, request: Request, _slug: str):
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
 | 
			
		||||
    def execute(self, request: Request, slug: str):
 | 
			
		||||
        """Execute flow for current user"""
 | 
			
		||||
        # Because we pre-plan the flow here, and not in the planner, we need to manually clear
 | 
			
		||||
        # the history of the inspector
 | 
			
		||||
 | 
			
		||||
@ -203,7 +203,8 @@ class FlowPlanner:
 | 
			
		||||
                "f(plan): building plan",
 | 
			
		||||
            )
 | 
			
		||||
            plan = self._build_plan(user, request, default_context)
 | 
			
		||||
            cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
 | 
			
		||||
            if self.use_cache:
 | 
			
		||||
                cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT)
 | 
			
		||||
            if not plan.bindings and not self.allow_empty_flows:
 | 
			
		||||
                raise EmptyFlowException()
 | 
			
		||||
            return plan
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ from rest_framework.test import APITestCase
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user
 | 
			
		||||
from authentik.flows.api.stages import StageSerializer, StageViewSet
 | 
			
		||||
from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.policies.dummy.models import DummyPolicy
 | 
			
		||||
from authentik.policies.models import PolicyBinding
 | 
			
		||||
from authentik.stages.dummy.models import DummyStage
 | 
			
		||||
@ -101,3 +102,21 @@ class TestFlowsAPI(APITestCase):
 | 
			
		||||
            reverse("authentik_api:stage-types"),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
    def test_execute(self):
 | 
			
		||||
        """Test execute endpoint"""
 | 
			
		||||
        user = create_test_admin_user()
 | 
			
		||||
        self.client.force_login(user)
 | 
			
		||||
 | 
			
		||||
        flow = Flow.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            slug=generate_id(),
 | 
			
		||||
            designation=FlowDesignation.AUTHENTICATION,
 | 
			
		||||
        )
 | 
			
		||||
        FlowStageBinding.objects.create(
 | 
			
		||||
            target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0
 | 
			
		||||
        )
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
            reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug})
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(response.status_code, 200)
 | 
			
		||||
 | 
			
		||||
@ -53,6 +53,7 @@ cache:
 | 
			
		||||
 | 
			
		||||
# result_backend:
 | 
			
		||||
#   url: ""
 | 
			
		||||
#   transport_options: ""
 | 
			
		||||
 | 
			
		||||
debug: false
 | 
			
		||||
remote_debug: false
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,7 @@ from authentik.outposts.models import (
 | 
			
		||||
    KubernetesServiceConnection,
 | 
			
		||||
    OutpostServiceConnection,
 | 
			
		||||
)
 | 
			
		||||
from authentik.rbac.filters import ObjectFilter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer):
 | 
			
		||||
@ -88,7 +89,7 @@ class ServiceConnectionViewSet(
 | 
			
		||||
        return Response(TypeCreateSerializer(data, many=True).data)
 | 
			
		||||
 | 
			
		||||
    @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)})
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[])
 | 
			
		||||
    @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter])
 | 
			
		||||
    def state(self, request: Request, pk: str) -> Response:
 | 
			
		||||
        """Get the service connection's state"""
 | 
			
		||||
        connection = self.get_object()
 | 
			
		||||
 | 
			
		||||
@ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel):
 | 
			
		||||
        verbose_name_plural = _("Authorization Codes")
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Authorization code for {self.provider} for user {self.user}"
 | 
			
		||||
        return f"Authorization code for {self.provider_id} for user {self.user_id}"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def serializer(self) -> Serializer:
 | 
			
		||||
@ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel):
 | 
			
		||||
        verbose_name_plural = _("OAuth2 Access Tokens")
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Access Token for {self.provider} for user {self.user}"
 | 
			
		||||
        return f"Access Token for {self.provider_id} for user {self.user_id}"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def id_token(self) -> IDToken:
 | 
			
		||||
@ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
 | 
			
		||||
        verbose_name_plural = _("OAuth2 Refresh Tokens")
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Refresh Token for {self.provider} for user {self.user}"
 | 
			
		||||
        return f"Refresh Token for {self.provider_id} for user {self.user_id}"
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def id_token(self) -> IDToken:
 | 
			
		||||
@ -443,4 +443,4 @@ class DeviceToken(ExpiringModel):
 | 
			
		||||
        verbose_name_plural = _("Device Tokens")
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Device Token for {self.provider}"
 | 
			
		||||
        return f"Device Token for {self.provider_id}"
 | 
			
		||||
 | 
			
		||||
@ -4,9 +4,10 @@ from urllib.parse import urlencode
 | 
			
		||||
 | 
			
		||||
from django.urls import reverse
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Application
 | 
			
		||||
from authentik.core.models import Application, Group
 | 
			
		||||
from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.policies.models import PolicyBinding
 | 
			
		||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
 | 
			
		||||
from authentik.providers.oauth2.tests.utils import OAuthTestCase
 | 
			
		||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
 | 
			
		||||
@ -77,3 +78,23 @@ class TesOAuth2DeviceInit(OAuthTestCase):
 | 
			
		||||
            + "?"
 | 
			
		||||
            + urlencode({QS_KEY_CODE: token.user_code}),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_device_init_denied(self):
 | 
			
		||||
        """Test device init"""
 | 
			
		||||
        group = Group.objects.create(name="foo")
 | 
			
		||||
        PolicyBinding.objects.create(
 | 
			
		||||
            group=group,
 | 
			
		||||
            target=self.application,
 | 
			
		||||
            order=0,
 | 
			
		||||
        )
 | 
			
		||||
        token = DeviceToken.objects.create(
 | 
			
		||||
            user_code="foo",
 | 
			
		||||
            provider=self.provider,
 | 
			
		||||
        )
 | 
			
		||||
        res = self.client.get(
 | 
			
		||||
            reverse("authentik_providers_oauth2_root:device-login")
 | 
			
		||||
            + "?"
 | 
			
		||||
            + urlencode({QS_KEY_CODE: token.user_code})
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(res.status_code, 200)
 | 
			
		||||
        self.assertIn(b"Permission denied", res.content)
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,7 @@ from jwt import PyJWKSet
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Application
 | 
			
		||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
			
		||||
from authentik.crypto.builder import PrivateKeyAlg
 | 
			
		||||
from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.oauth2.models import OAuth2Provider
 | 
			
		||||
@ -82,7 +83,7 @@ class TestJWKS(OAuthTestCase):
 | 
			
		||||
            client_id="test",
 | 
			
		||||
            authorization_flow=create_test_flow(),
 | 
			
		||||
            redirect_uris="http://local.invalid",
 | 
			
		||||
            signing_key=create_test_cert(use_ec_private_key=True),
 | 
			
		||||
            signing_key=create_test_cert(PrivateKeyAlg.ECDSA),
 | 
			
		||||
        )
 | 
			
		||||
        app = Application.objects.create(name="test", slug="test", provider=provider)
 | 
			
		||||
        response = self.client.get(
 | 
			
		||||
 | 
			
		||||
@ -11,10 +11,11 @@ from django.views.decorators.csrf import csrf_exempt
 | 
			
		||||
from rest_framework.throttling import AnonRateThrottle
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Application
 | 
			
		||||
from authentik.lib.config import CONFIG
 | 
			
		||||
from authentik.lib.utils.time import timedelta_from_string
 | 
			
		||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
 | 
			
		||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application
 | 
			
		||||
from authentik.providers.oauth2.views.device_init import QS_KEY_CODE
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -37,7 +38,9 @@ class DeviceView(View):
 | 
			
		||||
        ).first()
 | 
			
		||||
        if not provider:
 | 
			
		||||
            return HttpResponseBadRequest()
 | 
			
		||||
        if not get_application(provider):
 | 
			
		||||
        try:
 | 
			
		||||
            _ = provider.application
 | 
			
		||||
        except Application.DoesNotExist:
 | 
			
		||||
            return HttpResponseBadRequest()
 | 
			
		||||
        self.provider = provider
 | 
			
		||||
        self.client_id = client_id
 | 
			
		||||
 | 
			
		||||
@ -1,8 +1,9 @@
 | 
			
		||||
"""Device flow views"""
 | 
			
		||||
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from django.http import HttpRequest, HttpResponse
 | 
			
		||||
from django.utils.translation import gettext as _
 | 
			
		||||
from django.views import View
 | 
			
		||||
from rest_framework.exceptions import ValidationError
 | 
			
		||||
from rest_framework.fields import CharField, IntegerField
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
@ -16,7 +17,8 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO,
 | 
			
		||||
from authentik.flows.stage import ChallengeStageView
 | 
			
		||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
 | 
			
		||||
from authentik.lib.utils.urls import redirect_with_qs
 | 
			
		||||
from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider
 | 
			
		||||
from authentik.policies.views import PolicyAccessView
 | 
			
		||||
from authentik.providers.oauth2.models import DeviceToken
 | 
			
		||||
from authentik.providers.oauth2.views.device_finish import (
 | 
			
		||||
    PLAN_CONTEXT_DEVICE,
 | 
			
		||||
    OAuthDeviceCodeFinishStage,
 | 
			
		||||
@ -31,60 +33,52 @@ LOGGER = get_logger()
 | 
			
		||||
QS_KEY_CODE = "code"  # nosec
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_application(provider: OAuth2Provider) -> Application | None:
 | 
			
		||||
    """Get application from provider"""
 | 
			
		||||
    try:
 | 
			
		||||
        app = provider.application
 | 
			
		||||
        if not app:
 | 
			
		||||
class CodeValidatorView(PolicyAccessView):
 | 
			
		||||
    """Helper to validate frontside token"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, code: str, **kwargs: Any) -> None:
 | 
			
		||||
        super().__init__(**kwargs)
 | 
			
		||||
        self.code = code
 | 
			
		||||
 | 
			
		||||
    def resolve_provider_application(self):
 | 
			
		||||
        self.token = DeviceToken.objects.filter(user_code=self.code).first()
 | 
			
		||||
        if not self.token:
 | 
			
		||||
            raise Application.DoesNotExist
 | 
			
		||||
        self.provider = self.token.provider
 | 
			
		||||
        self.application = self.token.provider.application
 | 
			
		||||
 | 
			
		||||
    def get(self, request: HttpRequest, *args, **kwargs):
 | 
			
		||||
        scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider)
 | 
			
		||||
        planner = FlowPlanner(self.provider.authorization_flow)
 | 
			
		||||
        planner.allow_empty_flows = True
 | 
			
		||||
        planner.use_cache = False
 | 
			
		||||
        try:
 | 
			
		||||
            plan = planner.plan(
 | 
			
		||||
                request,
 | 
			
		||||
                {
 | 
			
		||||
                    PLAN_CONTEXT_SSO: True,
 | 
			
		||||
                    PLAN_CONTEXT_APPLICATION: self.application,
 | 
			
		||||
                    # OAuth2 related params
 | 
			
		||||
                    PLAN_CONTEXT_DEVICE: self.token,
 | 
			
		||||
                    # Consent related params
 | 
			
		||||
                    PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
 | 
			
		||||
                    % {"application": self.application.name},
 | 
			
		||||
                    PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
 | 
			
		||||
                },
 | 
			
		||||
            )
 | 
			
		||||
        except FlowNonApplicableException:
 | 
			
		||||
            LOGGER.warning("Flow not applicable to user")
 | 
			
		||||
            return None
 | 
			
		||||
        return app
 | 
			
		||||
    except Application.DoesNotExist:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def validate_code(code: int, request: HttpRequest) -> HttpResponse | None:
 | 
			
		||||
    """Validate user token"""
 | 
			
		||||
    token = DeviceToken.objects.filter(
 | 
			
		||||
        user_code=code,
 | 
			
		||||
    ).first()
 | 
			
		||||
    if not token:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    app = get_application(token.provider)
 | 
			
		||||
    if not app:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider)
 | 
			
		||||
    planner = FlowPlanner(token.provider.authorization_flow)
 | 
			
		||||
    planner.allow_empty_flows = True
 | 
			
		||||
    planner.use_cache = False
 | 
			
		||||
    try:
 | 
			
		||||
        plan = planner.plan(
 | 
			
		||||
            request,
 | 
			
		||||
            {
 | 
			
		||||
                PLAN_CONTEXT_SSO: True,
 | 
			
		||||
                PLAN_CONTEXT_APPLICATION: app,
 | 
			
		||||
                # OAuth2 related params
 | 
			
		||||
                PLAN_CONTEXT_DEVICE: token,
 | 
			
		||||
                # Consent related params
 | 
			
		||||
                PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.")
 | 
			
		||||
                % {"application": app.name},
 | 
			
		||||
                PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions,
 | 
			
		||||
            },
 | 
			
		||||
        plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
 | 
			
		||||
        request.session[SESSION_KEY_PLAN] = plan
 | 
			
		||||
        return redirect_with_qs(
 | 
			
		||||
            "authentik_core:if-flow",
 | 
			
		||||
            request.GET,
 | 
			
		||||
            flow_slug=self.token.provider.authorization_flow.slug,
 | 
			
		||||
        )
 | 
			
		||||
    except FlowNonApplicableException:
 | 
			
		||||
        LOGGER.warning("Flow not applicable to user")
 | 
			
		||||
        return None
 | 
			
		||||
    plan.insert_stage(in_memory_stage(OAuthDeviceCodeFinishStage))
 | 
			
		||||
    request.session[SESSION_KEY_PLAN] = plan
 | 
			
		||||
    return redirect_with_qs(
 | 
			
		||||
        "authentik_core:if-flow",
 | 
			
		||||
        request.GET,
 | 
			
		||||
        flow_slug=token.provider.authorization_flow.slug,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeviceEntryView(View):
 | 
			
		||||
class DeviceEntryView(PolicyAccessView):
 | 
			
		||||
    """View used to initiate the device-code flow, url entered by endusers"""
 | 
			
		||||
 | 
			
		||||
    def dispatch(self, request: HttpRequest) -> HttpResponse:
 | 
			
		||||
@ -94,7 +88,9 @@ class DeviceEntryView(View):
 | 
			
		||||
            LOGGER.info("Brand has no device code flow configured", brand=brand)
 | 
			
		||||
            return HttpResponse(status=404)
 | 
			
		||||
        if QS_KEY_CODE in request.GET:
 | 
			
		||||
            validation = validate_code(request.GET[QS_KEY_CODE], request)
 | 
			
		||||
            validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch(
 | 
			
		||||
                request
 | 
			
		||||
            )
 | 
			
		||||
            if validation:
 | 
			
		||||
                return validation
 | 
			
		||||
            LOGGER.info("Got code from query parameter but no matching token found")
 | 
			
		||||
@ -131,7 +127,7 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse):
 | 
			
		||||
 | 
			
		||||
    def validate_code(self, code: int) -> HttpResponse | None:
 | 
			
		||||
        """Validate code and save the returned http response"""
 | 
			
		||||
        response = validate_code(code, self.stage.request)
 | 
			
		||||
        response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request)
 | 
			
		||||
        if not response:
 | 
			
		||||
            raise ValidationError(_("Invalid code"), "invalid")
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,44 @@
 | 
			
		||||
# Generated by Django 5.0.4 on 2024-05-01 15:32
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_providers_saml", "0013_samlprovider_default_relay_state"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="samlprovider",
 | 
			
		||||
            name="digest_algorithm",
 | 
			
		||||
            field=models.TextField(
 | 
			
		||||
                choices=[
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"),
 | 
			
		||||
                ],
 | 
			
		||||
                default="http://www.w3.org/2001/04/xmlenc#sha256",
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="samlprovider",
 | 
			
		||||
            name="signature_algorithm",
 | 
			
		||||
            field=models.TextField(
 | 
			
		||||
                choices=[
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"),
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"),
 | 
			
		||||
                ],
 | 
			
		||||
                default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -11,6 +11,10 @@ from authentik.crypto.models import CertificateKeyPair
 | 
			
		||||
from authentik.lib.utils.time import timedelta_string_validator
 | 
			
		||||
from authentik.sources.saml.processors.constants import (
 | 
			
		||||
    DSA_SHA1,
 | 
			
		||||
    ECDSA_SHA1,
 | 
			
		||||
    ECDSA_SHA256,
 | 
			
		||||
    ECDSA_SHA384,
 | 
			
		||||
    ECDSA_SHA512,
 | 
			
		||||
    RSA_SHA1,
 | 
			
		||||
    RSA_SHA256,
 | 
			
		||||
    RSA_SHA384,
 | 
			
		||||
@ -92,8 +96,7 @@ class SAMLProvider(Provider):
 | 
			
		||||
        ),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    digest_algorithm = models.CharField(
 | 
			
		||||
        max_length=50,
 | 
			
		||||
    digest_algorithm = models.TextField(
 | 
			
		||||
        choices=(
 | 
			
		||||
            (SHA1, _("SHA1")),
 | 
			
		||||
            (SHA256, _("SHA256")),
 | 
			
		||||
@ -102,13 +105,16 @@ class SAMLProvider(Provider):
 | 
			
		||||
        ),
 | 
			
		||||
        default=SHA256,
 | 
			
		||||
    )
 | 
			
		||||
    signature_algorithm = models.CharField(
 | 
			
		||||
        max_length=50,
 | 
			
		||||
    signature_algorithm = models.TextField(
 | 
			
		||||
        choices=(
 | 
			
		||||
            (RSA_SHA1, _("RSA-SHA1")),
 | 
			
		||||
            (RSA_SHA256, _("RSA-SHA256")),
 | 
			
		||||
            (RSA_SHA384, _("RSA-SHA384")),
 | 
			
		||||
            (RSA_SHA512, _("RSA-SHA512")),
 | 
			
		||||
            (ECDSA_SHA1, _("ECDSA-SHA1")),
 | 
			
		||||
            (ECDSA_SHA256, _("ECDSA-SHA256")),
 | 
			
		||||
            (ECDSA_SHA384, _("ECDSA-SHA384")),
 | 
			
		||||
            (ECDSA_SHA512, _("ECDSA-SHA512")),
 | 
			
		||||
            (DSA_SHA1, _("DSA-SHA1")),
 | 
			
		||||
        ),
 | 
			
		||||
        default=RSA_SHA256,
 | 
			
		||||
 | 
			
		||||
@ -7,13 +7,14 @@ from lxml import etree  # nosec
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Application
 | 
			
		||||
from authentik.core.tests.utils import create_test_cert, create_test_flow
 | 
			
		||||
from authentik.crypto.builder import PrivateKeyAlg
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.lib.tests.utils import load_fixture
 | 
			
		||||
from authentik.lib.xml import lxml_from_string
 | 
			
		||||
from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
 | 
			
		||||
from authentik.providers.saml.processors.metadata import MetadataProcessor
 | 
			
		||||
from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser
 | 
			
		||||
from authentik.sources.saml.processors.constants import NS_MAP, NS_SAML_METADATA
 | 
			
		||||
from authentik.sources.saml.processors.constants import ECDSA_SHA256, NS_MAP, NS_SAML_METADATA
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestServiceProviderMetadataParser(TestCase):
 | 
			
		||||
@ -107,12 +108,41 @@ class TestServiceProviderMetadataParser(TestCase):
 | 
			
		||||
                load_fixture("fixtures/cert.xml").replace("/apps/user_saml", "")
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_signature(self):
 | 
			
		||||
        """Test signature validation"""
 | 
			
		||||
    def test_signature_rsa(self):
 | 
			
		||||
        """Test signature validation (RSA)"""
 | 
			
		||||
        provider = SAMLProvider.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            authorization_flow=self.flow,
 | 
			
		||||
            signing_kp=create_test_cert(),
 | 
			
		||||
            signing_kp=create_test_cert(PrivateKeyAlg.RSA),
 | 
			
		||||
        )
 | 
			
		||||
        Application.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            slug=generate_id(),
 | 
			
		||||
            provider=provider,
 | 
			
		||||
        )
 | 
			
		||||
        request = self.factory.get("/")
 | 
			
		||||
        metadata = MetadataProcessor(provider, request).build_entity_descriptor()
 | 
			
		||||
 | 
			
		||||
        root = fromstring(metadata.encode())
 | 
			
		||||
        xmlsec.tree.add_ids(root, ["ID"])
 | 
			
		||||
        signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
 | 
			
		||||
        signature_node = signature_nodes[0]
 | 
			
		||||
        ctx = xmlsec.SignatureContext()
 | 
			
		||||
        key = xmlsec.Key.from_memory(
 | 
			
		||||
            provider.signing_kp.certificate_data,
 | 
			
		||||
            xmlsec.constants.KeyDataFormatCertPem,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
        ctx.key = key
 | 
			
		||||
        ctx.verify(signature_node)
 | 
			
		||||
 | 
			
		||||
    def test_signature_ecdsa(self):
 | 
			
		||||
        """Test signature validation (ECDSA)"""
 | 
			
		||||
        provider = SAMLProvider.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            authorization_flow=self.flow,
 | 
			
		||||
            signing_kp=create_test_cert(PrivateKeyAlg.ECDSA),
 | 
			
		||||
            signature_algorithm=ECDSA_SHA256,
 | 
			
		||||
        )
 | 
			
		||||
        Application.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
 | 
			
		||||
@ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
        if not scim_group:
 | 
			
		||||
            self.logger.debug("Group does not exist in SCIM, skipping")
 | 
			
		||||
            return None
 | 
			
		||||
        response = self._request("DELETE", f"/Groups/{scim_group.id}")
 | 
			
		||||
        response = self._request("DELETE", f"/Groups/{scim_group.scim_id}")
 | 
			
		||||
        scim_group.delete()
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
@ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
        for user in connections:
 | 
			
		||||
            members.append(
 | 
			
		||||
                GroupMember(
 | 
			
		||||
                    value=user.id,
 | 
			
		||||
                    value=user.scim_id,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        if members:
 | 
			
		||||
@ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
                exclude_unset=True,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"])
 | 
			
		||||
        scim_id = response.get("id")
 | 
			
		||||
        if not scim_id or scim_id == "":
 | 
			
		||||
            raise StopSync("SCIM Response with missing or invalid `id`")
 | 
			
		||||
        SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id)
 | 
			
		||||
 | 
			
		||||
    def _update(self, group: Group, connection: SCIMGroup):
 | 
			
		||||
        """Update existing group"""
 | 
			
		||||
        scim_group = self.to_scim(group)
 | 
			
		||||
        scim_group.id = connection.id
 | 
			
		||||
        scim_group.id = connection.scim_id
 | 
			
		||||
        try:
 | 
			
		||||
            return self._request(
 | 
			
		||||
                "PUT",
 | 
			
		||||
                f"/Groups/{scim_group.id}",
 | 
			
		||||
                f"/Groups/{connection.scim_id}",
 | 
			
		||||
                json=scim_group.model_dump(
 | 
			
		||||
                    mode="json",
 | 
			
		||||
                    exclude_unset=True,
 | 
			
		||||
@ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
            return
 | 
			
		||||
        user_ids = list(
 | 
			
		||||
            SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
 | 
			
		||||
                "id", flat=True
 | 
			
		||||
                "scim_id", flat=True
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        if len(user_ids) < 1:
 | 
			
		||||
            return
 | 
			
		||||
        self._patch(
 | 
			
		||||
            scim_group.id,
 | 
			
		||||
            scim_group.scim_id,
 | 
			
		||||
            PatchOperation(
 | 
			
		||||
                op=PatchOp.add,
 | 
			
		||||
                path="members",
 | 
			
		||||
@ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
 | 
			
		||||
            return
 | 
			
		||||
        user_ids = list(
 | 
			
		||||
            SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list(
 | 
			
		||||
                "id", flat=True
 | 
			
		||||
                "scim_id", flat=True
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        if len(user_ids) < 1:
 | 
			
		||||
            return
 | 
			
		||||
        self._patch(
 | 
			
		||||
            scim_group.id,
 | 
			
		||||
            scim_group.scim_id,
 | 
			
		||||
            PatchOperation(
 | 
			
		||||
                op=PatchOp.remove,
 | 
			
		||||
                path="members",
 | 
			
		||||
 | 
			
		||||
@ -9,13 +9,14 @@ from pydanticscim.service_provider import (
 | 
			
		||||
)
 | 
			
		||||
from pydanticscim.user import User as BaseUser
 | 
			
		||||
 | 
			
		||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
 | 
			
		||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class User(BaseUser):
 | 
			
		||||
    """Modified User schema with added externalId field"""
 | 
			
		||||
 | 
			
		||||
    schemas: list[str] = [
 | 
			
		||||
        "urn:ietf:params:scim:schemas:core:2.0:User",
 | 
			
		||||
    ]
 | 
			
		||||
    schemas: list[str] = [SCIM_USER_SCHEMA]
 | 
			
		||||
    externalId: str | None = None
 | 
			
		||||
    meta: dict | None = None
 | 
			
		||||
 | 
			
		||||
@ -23,9 +24,7 @@ class User(BaseUser):
 | 
			
		||||
class Group(BaseGroup):
 | 
			
		||||
    """Modified Group schema with added externalId field"""
 | 
			
		||||
 | 
			
		||||
    schemas: list[str] = [
 | 
			
		||||
        "urn:ietf:params:scim:schemas:core:2.0:Group",
 | 
			
		||||
    ]
 | 
			
		||||
    schemas: list[str] = [SCIM_GROUP_SCHEMA]
 | 
			
		||||
    externalId: str | None = None
 | 
			
		||||
    meta: dict | None = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
 | 
			
		||||
        if not scim_user:
 | 
			
		||||
            self.logger.debug("User does not exist in SCIM, skipping")
 | 
			
		||||
            return None
 | 
			
		||||
        response = self._request("DELETE", f"/Users/{scim_user.id}")
 | 
			
		||||
        response = self._request("DELETE", f"/Users/{scim_user.scim_id}")
 | 
			
		||||
        scim_user.delete()
 | 
			
		||||
        return response
 | 
			
		||||
 | 
			
		||||
@ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
 | 
			
		||||
                exclude_unset=True,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"])
 | 
			
		||||
        scim_id = response.get("id")
 | 
			
		||||
        if not scim_id or scim_id == "":
 | 
			
		||||
            raise StopSync("SCIM Response with missing or invalid `id`")
 | 
			
		||||
        SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id)
 | 
			
		||||
 | 
			
		||||
    def _update(self, user: User, connection: SCIMUser):
 | 
			
		||||
        """Update existing user"""
 | 
			
		||||
        scim_user = self.to_scim(user)
 | 
			
		||||
        scim_user.id = connection.id
 | 
			
		||||
        scim_user.id = connection.scim_id
 | 
			
		||||
        self._request(
 | 
			
		||||
            "PUT",
 | 
			
		||||
            f"/Users/{connection.id}",
 | 
			
		||||
            f"/Users/{connection.scim_id}",
 | 
			
		||||
            json=scim_user.model_dump(
 | 
			
		||||
                mode="json",
 | 
			
		||||
                exclude_unset=True,
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@
 | 
			
		||||
from structlog.stdlib import get_logger
 | 
			
		||||
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_task_wrapper
 | 
			
		||||
from authentik.tenants.management import TenantCommand
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
@ -21,4 +21,4 @@ class Command(TenantCommand):
 | 
			
		||||
            if not provider:
 | 
			
		||||
                LOGGER.warning("Provider does not exist", name=provider_name)
 | 
			
		||||
                continue
 | 
			
		||||
            scim_sync.delay(provider.pk).get()
 | 
			
		||||
            scim_task_wrapper(provider.pk).get()
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,76 @@
 | 
			
		||||
# Generated by Django 5.0.4 on 2024-05-03 12:38
 | 
			
		||||
 | 
			
		||||
import uuid
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
from django.apps.registry import Apps
 | 
			
		||||
 | 
			
		||||
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 | 
			
		||||
 | 
			
		||||
from authentik.lib.migrations import progress_bar
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
 | 
			
		||||
    SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser")
 | 
			
		||||
    SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup")
 | 
			
		||||
    db_alias = schema_editor.connection.alias
 | 
			
		||||
    print("\nFixing primary key for SCIM users, this might take a couple of minutes...")
 | 
			
		||||
    for user in progress_bar(SCIMUser.objects.using(db_alias).all()):
 | 
			
		||||
        SCIMUser.objects.using(db_alias).filter(
 | 
			
		||||
            pk=user.pk, user=user.user_id, provider=user.provider_id
 | 
			
		||||
        ).update(scim_id=user.pk, id=uuid.uuid4())
 | 
			
		||||
 | 
			
		||||
    print("\nFixing primary key for SCIM groups, this might take a couple of minutes...")
 | 
			
		||||
    for group in progress_bar(SCIMGroup.objects.using(db_alias).all()):
 | 
			
		||||
        SCIMGroup.objects.using(db_alias).filter(
 | 
			
		||||
            pk=group.pk, group=group.group_id, provider=group.provider_id
 | 
			
		||||
        ).update(scim_id=group.pk, id=uuid.uuid4())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        (
 | 
			
		||||
            "authentik_providers_scim",
 | 
			
		||||
            "0001_squashed_0006_rename_parent_group_scimprovider_filter_group",
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="scimgroup",
 | 
			
		||||
            name="scim_id",
 | 
			
		||||
            field=models.TextField(default="temp"),
 | 
			
		||||
            preserve_default=False,
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AddField(
 | 
			
		||||
            model_name="scimuser",
 | 
			
		||||
            name="scim_id",
 | 
			
		||||
            field=models.TextField(default="temp"),
 | 
			
		||||
            preserve_default=False,
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.RunPython(fix_scim_user_group_pk),
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="scimgroup",
 | 
			
		||||
            name="id",
 | 
			
		||||
            field=models.UUIDField(
 | 
			
		||||
                default=uuid.uuid4, editable=False, primary_key=True, serialize=False
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="scimuser",
 | 
			
		||||
            name="id",
 | 
			
		||||
            field=models.UUIDField(
 | 
			
		||||
                default=uuid.uuid4, editable=False, primary_key=True, serialize=False
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()),
 | 
			
		||||
        migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()),
 | 
			
		||||
        migrations.AlterUniqueTogether(
 | 
			
		||||
            name="scimgroup",
 | 
			
		||||
            unique_together={("scim_id", "group", "provider")},
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AlterUniqueTogether(
 | 
			
		||||
            name="scimuser",
 | 
			
		||||
            unique_together={("scim_id", "user", "provider")},
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
"""SCIM Provider models"""
 | 
			
		||||
 | 
			
		||||
from uuid import uuid4
 | 
			
		||||
 | 
			
		||||
from django.core.cache import cache
 | 
			
		||||
from django.db import models
 | 
			
		||||
from django.db.models import QuerySet
 | 
			
		||||
@ -97,26 +99,28 @@ class SCIMMapping(PropertyMapping):
 | 
			
		||||
class SCIMUser(models.Model):
 | 
			
		||||
    """Mapping of a user and provider to a SCIM user ID"""
 | 
			
		||||
 | 
			
		||||
    id = models.TextField(primary_key=True)
 | 
			
		||||
    id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
			
		||||
    scim_id = models.TextField()
 | 
			
		||||
    user = models.ForeignKey(User, on_delete=models.CASCADE)
 | 
			
		||||
    provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = (("id", "user", "provider"),)
 | 
			
		||||
        unique_together = (("scim_id", "user", "provider"),)
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"SCIM User {self.user.username} to {self.provider.name}"
 | 
			
		||||
        return f"SCIM User {self.user_id} to {self.provider_id}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SCIMGroup(models.Model):
 | 
			
		||||
    """Mapping of a group and provider to a SCIM user ID"""
 | 
			
		||||
 | 
			
		||||
    id = models.TextField(primary_key=True)
 | 
			
		||||
    id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
 | 
			
		||||
    scim_id = models.TextField()
 | 
			
		||||
    group = models.ForeignKey(Group, on_delete=models.CASCADE)
 | 
			
		||||
    provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE)
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = (("id", "group", "provider"),)
 | 
			
		||||
        unique_together = (("scim_id", "group", "provider"),)
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"SCIM Group {self.group.name} to {self.provider.name}"
 | 
			
		||||
        return f"SCIM Group {self.group_id} to {self.provider_id}"
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,7 @@ from structlog.stdlib import get_logger
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.lib.utils.reflection import class_to_path
 | 
			
		||||
from authentik.providers.scim.models import SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper
 | 
			
		||||
 | 
			
		||||
LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
@ -17,7 +17,7 @@ LOGGER = get_logger()
 | 
			
		||||
@receiver(post_save, sender=SCIMProvider)
 | 
			
		||||
def post_save_provider(sender: type[Model], instance, created: bool, **_):
 | 
			
		||||
    """Trigger sync when SCIM provider is saved"""
 | 
			
		||||
    scim_sync.delay(instance.pk)
 | 
			
		||||
    scim_task_wrapper(instance.pk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(post_save, sender=User)
 | 
			
		||||
 | 
			
		||||
@ -38,7 +38,23 @@ def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient:
 | 
			
		||||
def scim_sync_all():
 | 
			
		||||
    """Run sync for all providers"""
 | 
			
		||||
    for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False):
 | 
			
		||||
        scim_sync.delay(provider.pk)
 | 
			
		||||
        scim_task_wrapper(provider.pk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def scim_task_wrapper(provider_pk: int):
 | 
			
		||||
    """Wrap scim_sync to set the correct timeouts"""
 | 
			
		||||
    provider: SCIMProvider = SCIMProvider.objects.filter(
 | 
			
		||||
        pk=provider_pk, backchannel_application__isnull=False
 | 
			
		||||
    ).first()
 | 
			
		||||
    if not provider:
 | 
			
		||||
        return
 | 
			
		||||
    users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
 | 
			
		||||
    groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
 | 
			
		||||
    soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT
 | 
			
		||||
    time_limit = soft_time_limit * 1.5
 | 
			
		||||
    return scim_sync.apply_async(
 | 
			
		||||
        (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CELERY_APP.task(bind=True, base=SystemTask)
 | 
			
		||||
@ -60,7 +76,7 @@ def scim_sync(self: SystemTask, provider_pk: int) -> None:
 | 
			
		||||
    users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE)
 | 
			
		||||
    groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE)
 | 
			
		||||
    self.soft_time_limit = self.time_limit = (
 | 
			
		||||
        users_paginator.count + groups_paginator.count
 | 
			
		||||
        users_paginator.num_pages + groups_paginator.num_pages
 | 
			
		||||
    ) * PAGE_TIMEOUT
 | 
			
		||||
    with allow_join_result():
 | 
			
		||||
        try:
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.scim.clients.schema import ServiceProviderConfiguration
 | 
			
		||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_task_wrapper
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.configure()
 | 
			
		||||
            scim_sync.delay(self.provider.pk).get()
 | 
			
		||||
            scim_task_wrapper(self.provider.pk).get()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(mocker.call_count, 6)
 | 
			
		||||
            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
			
		||||
@ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self.configure()
 | 
			
		||||
            scim_sync.delay(self.provider.pk).get()
 | 
			
		||||
            scim_task_wrapper(self.provider.pk).get()
 | 
			
		||||
 | 
			
		||||
            self.assertEqual(mocker.call_count, 6)
 | 
			
		||||
            self.assertEqual(mocker.request_history[0].method, "GET")
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint
 | 
			
		||||
from authentik.core.models import Application, Group, User
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.providers.scim.models import SCIMMapping, SCIMProvider
 | 
			
		||||
from authentik.providers.scim.tasks import scim_sync
 | 
			
		||||
from authentik.providers.scim.tasks import scim_task_wrapper
 | 
			
		||||
from authentik.tenants.models import Tenant
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -88,6 +88,72 @@ class SCIMUserTests(TestCase):
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @Mocker()
 | 
			
		||||
    def test_user_create_different_provider_same_id(self, mock: Mocker):
 | 
			
		||||
        """Test user creation with multiple providers that happen
 | 
			
		||||
        to return the same object ID"""
 | 
			
		||||
        # Create duplicate provider
 | 
			
		||||
        provider: SCIMProvider = SCIMProvider.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            url="https://localhost",
 | 
			
		||||
            token=generate_id(),
 | 
			
		||||
            exclude_users_service_account=True,
 | 
			
		||||
        )
 | 
			
		||||
        app: Application = Application.objects.create(
 | 
			
		||||
            name=generate_id(),
 | 
			
		||||
            slug=generate_id(),
 | 
			
		||||
        )
 | 
			
		||||
        app.backchannel_providers.add(provider)
 | 
			
		||||
        provider.property_mappings.add(
 | 
			
		||||
            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
 | 
			
		||||
        )
 | 
			
		||||
        provider.property_mappings_group.add(
 | 
			
		||||
            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        scim_id = generate_id()
 | 
			
		||||
        mock.get(
 | 
			
		||||
            "https://localhost/ServiceProviderConfig",
 | 
			
		||||
            json={},
 | 
			
		||||
        )
 | 
			
		||||
        mock.post(
 | 
			
		||||
            "https://localhost/Users",
 | 
			
		||||
            json={
 | 
			
		||||
                "id": scim_id,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        uid = generate_id()
 | 
			
		||||
        user = User.objects.create(
 | 
			
		||||
            username=uid,
 | 
			
		||||
            name=f"{uid} {uid}",
 | 
			
		||||
            email=f"{uid}@goauthentik.io",
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(mock.call_count, 4)
 | 
			
		||||
        self.assertEqual(mock.request_history[0].method, "GET")
 | 
			
		||||
        self.assertEqual(mock.request_history[1].method, "POST")
 | 
			
		||||
        self.assertJSONEqual(
 | 
			
		||||
            mock.request_history[1].body,
 | 
			
		||||
            {
 | 
			
		||||
                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
 | 
			
		||||
                "active": True,
 | 
			
		||||
                "emails": [
 | 
			
		||||
                    {
 | 
			
		||||
                        "primary": True,
 | 
			
		||||
                        "type": "other",
 | 
			
		||||
                        "value": f"{uid}@goauthentik.io",
 | 
			
		||||
                    }
 | 
			
		||||
                ],
 | 
			
		||||
                "externalId": user.uid,
 | 
			
		||||
                "name": {
 | 
			
		||||
                    "familyName": uid,
 | 
			
		||||
                    "formatted": f"{uid} {uid}",
 | 
			
		||||
                    "givenName": uid,
 | 
			
		||||
                },
 | 
			
		||||
                "displayName": f"{uid} {uid}",
 | 
			
		||||
                "userName": uid,
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @Mocker()
 | 
			
		||||
    def test_user_create_update(self, mock: Mocker):
 | 
			
		||||
        """Test user creation and update"""
 | 
			
		||||
@ -236,7 +302,7 @@ class SCIMUserTests(TestCase):
 | 
			
		||||
            email=f"{uid}@goauthentik.io",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        scim_sync.delay(self.provider.pk).get()
 | 
			
		||||
        scim_task_wrapper(self.provider.pk).get()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(mock.call_count, 5)
 | 
			
		||||
        self.assertEqual(mock.request_history[0].method, "GET")
 | 
			
		||||
 | 
			
		||||
@ -25,7 +25,7 @@ class ObjectFilter(ObjectPermissionsFilter):
 | 
			
		||||
        # Outposts (which are the only objects using internal service accounts)
 | 
			
		||||
        # except requests to return an empty list when they have no objects
 | 
			
		||||
        # assigned
 | 
			
		||||
        if request.user.type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
 | 
			
		||||
        if getattr(request.user, "type", UserTypes.INTERNAL) == UserTypes.INTERNAL_SERVICE_ACCOUNT:
 | 
			
		||||
            return queryset
 | 
			
		||||
        if not queryset.exists():
 | 
			
		||||
            # User doesn't have direct permission to all objects
 | 
			
		||||
 | 
			
		||||
@ -376,7 +376,13 @@ CELERY = {
 | 
			
		||||
    "task_default_queue": "authentik",
 | 
			
		||||
    "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")),
 | 
			
		||||
    "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")),
 | 
			
		||||
    "broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"),
 | 
			
		||||
    "broker_transport_options": CONFIG.get_dict_from_b64_json(
 | 
			
		||||
        "broker.transport_options", {"retry_policy": {"timeout": 5.0}}
 | 
			
		||||
    ),
 | 
			
		||||
    "result_backend_transport_options": CONFIG.get_dict_from_b64_json(
 | 
			
		||||
        "result_backend.transport_options", {"retry_policy": {"timeout": 5.0}}
 | 
			
		||||
    ),
 | 
			
		||||
    "redis_retry_on_timeout": True,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# Sentry integration
 | 
			
		||||
 | 
			
		||||
@ -80,7 +80,7 @@ class OAuth2Client(BaseOAuthClient):
 | 
			
		||||
            access_token_url = self.source.source_type.access_token_url or ""
 | 
			
		||||
            if self.source.source_type.urls_customizable and self.source.access_token_url:
 | 
			
		||||
                access_token_url = self.source.access_token_url
 | 
			
		||||
            response = self.session.request(
 | 
			
		||||
            response = self.do_request(
 | 
			
		||||
                "post", access_token_url, data=args, headers=self._default_headers, **request_kwargs
 | 
			
		||||
            )
 | 
			
		||||
            response.raise_for_status()
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										37
									
								
								authentik/sources/oauth/tests/test_type_apple.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								authentik/sources/oauth/tests/test_type_apple.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
"""Apple Type tests"""
 | 
			
		||||
 | 
			
		||||
from django.test import RequestFactory, TestCase
 | 
			
		||||
from guardian.shortcuts import get_anonymous_user
 | 
			
		||||
 | 
			
		||||
from authentik.lib.generators import generate_id
 | 
			
		||||
from authentik.lib.tests.utils import dummy_get_response
 | 
			
		||||
from authentik.root.middleware import SessionMiddleware
 | 
			
		||||
from authentik.sources.oauth.models import OAuthSource
 | 
			
		||||
from authentik.sources.oauth.types.registry import registry
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTypeApple(TestCase):
 | 
			
		||||
    """OAuth Source tests"""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        self.source = OAuthSource.objects.create(
 | 
			
		||||
            name="test",
 | 
			
		||||
            slug="test",
 | 
			
		||||
            provider_type="apple",
 | 
			
		||||
            authorization_url="",
 | 
			
		||||
            profile_url="",
 | 
			
		||||
            consumer_key=generate_id(),
 | 
			
		||||
        )
 | 
			
		||||
        self.factory = RequestFactory()
 | 
			
		||||
 | 
			
		||||
    def test_login_challenge(self):
 | 
			
		||||
        """Test login_challenge"""
 | 
			
		||||
        request = self.factory.get("/")
 | 
			
		||||
        request.user = get_anonymous_user()
 | 
			
		||||
 | 
			
		||||
        middleware = SessionMiddleware(dummy_get_response)
 | 
			
		||||
        middleware.process_request(request)
 | 
			
		||||
        request.session.save()
 | 
			
		||||
        oauth_type = registry.find_type("apple")
 | 
			
		||||
        challenge = oauth_type().login_challenge(self.source, request)
 | 
			
		||||
        self.assertTrue(challenge.is_valid(raise_exception=True))
 | 
			
		||||
@ -125,7 +125,7 @@ class AppleType(SourceType):
 | 
			
		||||
        )
 | 
			
		||||
        args = apple_client.get_redirect_args()
 | 
			
		||||
        return AppleLoginChallenge(
 | 
			
		||||
            instance={
 | 
			
		||||
            data={
 | 
			
		||||
                "client_id": apple_client.get_client_id(),
 | 
			
		||||
                "scope": "name email",
 | 
			
		||||
                "redirect_uri": args["redirect_uri"],
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ class PlexSource(Source):
 | 
			
		||||
            icon = static("authentik/sources/plex.svg")
 | 
			
		||||
        return UILoginButton(
 | 
			
		||||
            challenge=PlexAuthenticationChallenge(
 | 
			
		||||
                {
 | 
			
		||||
                data={
 | 
			
		||||
                    "type": ChallengeTypes.NATIVE.value,
 | 
			
		||||
                    "component": "ak-source-plex",
 | 
			
		||||
                    "client_id": self.client_id,
 | 
			
		||||
 | 
			
		||||
@ -40,6 +40,11 @@ class TestPlexSource(TestCase):
 | 
			
		||||
            slug="test",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_login_challenge(self):
 | 
			
		||||
        """Test login_challenge"""
 | 
			
		||||
        ui_login_button = self.source.ui_login_button(None)
 | 
			
		||||
        self.assertTrue(ui_login_button.challenge.is_valid(raise_exception=True))
 | 
			
		||||
 | 
			
		||||
    def test_get_user_info(self):
 | 
			
		||||
        """Test get_user_info"""
 | 
			
		||||
        token = generate_key()
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,44 @@
 | 
			
		||||
# Generated by Django 5.0.4 on 2024-05-01 15:44
 | 
			
		||||
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_sources_saml", "0013_samlsource_verification_kp_and_more"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="samlsource",
 | 
			
		||||
            name="digest_algorithm",
 | 
			
		||||
            field=models.TextField(
 | 
			
		||||
                choices=[
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"),
 | 
			
		||||
                ],
 | 
			
		||||
                default="http://www.w3.org/2001/04/xmlenc#sha256",
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="samlsource",
 | 
			
		||||
            name="signature_algorithm",
 | 
			
		||||
            field=models.TextField(
 | 
			
		||||
                choices=[
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"),
 | 
			
		||||
                    ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"),
 | 
			
		||||
                    ("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"),
 | 
			
		||||
                ],
 | 
			
		||||
                default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -15,6 +15,10 @@ from authentik.flows.models import Flow
 | 
			
		||||
from authentik.lib.utils.time import timedelta_string_validator
 | 
			
		||||
from authentik.sources.saml.processors.constants import (
 | 
			
		||||
    DSA_SHA1,
 | 
			
		||||
    ECDSA_SHA1,
 | 
			
		||||
    ECDSA_SHA256,
 | 
			
		||||
    ECDSA_SHA384,
 | 
			
		||||
    ECDSA_SHA512,
 | 
			
		||||
    RSA_SHA1,
 | 
			
		||||
    RSA_SHA256,
 | 
			
		||||
    RSA_SHA384,
 | 
			
		||||
@ -143,8 +147,7 @@ class SAMLSource(Source):
 | 
			
		||||
        verbose_name=_("Signing Keypair"),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    digest_algorithm = models.CharField(
 | 
			
		||||
        max_length=50,
 | 
			
		||||
    digest_algorithm = models.TextField(
 | 
			
		||||
        choices=(
 | 
			
		||||
            (SHA1, _("SHA1")),
 | 
			
		||||
            (SHA256, _("SHA256")),
 | 
			
		||||
@ -153,13 +156,16 @@ class SAMLSource(Source):
 | 
			
		||||
        ),
 | 
			
		||||
        default=SHA256,
 | 
			
		||||
    )
 | 
			
		||||
    signature_algorithm = models.CharField(
 | 
			
		||||
        max_length=50,
 | 
			
		||||
    signature_algorithm = models.TextField(
 | 
			
		||||
        choices=(
 | 
			
		||||
            (RSA_SHA1, _("RSA-SHA1")),
 | 
			
		||||
            (RSA_SHA256, _("RSA-SHA256")),
 | 
			
		||||
            (RSA_SHA384, _("RSA-SHA384")),
 | 
			
		||||
            (RSA_SHA512, _("RSA-SHA512")),
 | 
			
		||||
            (ECDSA_SHA1, _("ECDSA-SHA1")),
 | 
			
		||||
            (ECDSA_SHA256, _("ECDSA-SHA256")),
 | 
			
		||||
            (ECDSA_SHA384, _("ECDSA-SHA384")),
 | 
			
		||||
            (ECDSA_SHA512, _("ECDSA-SHA512")),
 | 
			
		||||
            (DSA_SHA1, _("DSA-SHA1")),
 | 
			
		||||
        ),
 | 
			
		||||
        default=RSA_SHA256,
 | 
			
		||||
 | 
			
		||||
@ -26,9 +26,16 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
 | 
			
		||||
 | 
			
		||||
DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
 | 
			
		||||
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
 | 
			
		||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2
 | 
			
		||||
RSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"
 | 
			
		||||
RSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384"
 | 
			
		||||
RSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512"
 | 
			
		||||
# https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.6
 | 
			
		||||
ECDSA_SHA1 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1"
 | 
			
		||||
ECDSA_SHA224 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha224"
 | 
			
		||||
ECDSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256"
 | 
			
		||||
ECDSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384"
 | 
			
		||||
ECDSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512"
 | 
			
		||||
 | 
			
		||||
SHA1 = "http://www.w3.org/2000/09/xmldsig#sha1"
 | 
			
		||||
SHA256 = "http://www.w3.org/2001/04/xmlenc#sha256"
 | 
			
		||||
@ -41,6 +48,11 @@ SIGN_ALGORITHM_TRANSFORM_MAP = {
 | 
			
		||||
    RSA_SHA256: xmlsec.constants.TransformRsaSha256,
 | 
			
		||||
    RSA_SHA384: xmlsec.constants.TransformRsaSha384,
 | 
			
		||||
    RSA_SHA512: xmlsec.constants.TransformRsaSha512,
 | 
			
		||||
    ECDSA_SHA1: xmlsec.constants.TransformEcdsaSha1,
 | 
			
		||||
    ECDSA_SHA224: xmlsec.constants.TransformEcdsaSha224,
 | 
			
		||||
    ECDSA_SHA256: xmlsec.constants.TransformEcdsaSha256,
 | 
			
		||||
    ECDSA_SHA384: xmlsec.constants.TransformEcdsaSha384,
 | 
			
		||||
    ECDSA_SHA512: xmlsec.constants.TransformEcdsaSha512,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
DIGEST_ALGORITHM_TRANSLATION_MAP = {
 | 
			
		||||
 | 
			
		||||
@ -60,7 +60,7 @@ class SCIMSourceUser(SerializerModel):
 | 
			
		||||
        unique_together = (("id", "user", "source"),)
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"SCIM User {self.user.username} to {self.source.name}"
 | 
			
		||||
        return f"SCIM User {self.user_id} to {self.source_id}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SCIMSourceGroup(SerializerModel):
 | 
			
		||||
@ -81,4 +81,4 @@ class SCIMSourceGroup(SerializerModel):
 | 
			
		||||
        unique_together = (("id", "group", "source"),)
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return f"SCIM Group {self.group.name} to {self.source.name}"
 | 
			
		||||
        return f"SCIM Group {self.group_id} to {self.source_id}"
 | 
			
		||||
 | 
			
		||||
@ -2,9 +2,11 @@ from django.db.models import Model
 | 
			
		||||
from django.db.models.signals import pre_delete, pre_save
 | 
			
		||||
from django.dispatch import receiver
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Token, TokenIntents, User, UserTypes
 | 
			
		||||
from authentik.core.models import USER_PATH_SYSTEM_PREFIX, Token, TokenIntents, User, UserTypes
 | 
			
		||||
from authentik.sources.scim.models import SCIMSource
 | 
			
		||||
 | 
			
		||||
USER_PATH_SOURCE_SCIM = USER_PATH_SYSTEM_PREFIX + "/sources/scim"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@receiver(pre_save, sender=SCIMSource)
 | 
			
		||||
def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_):
 | 
			
		||||
@ -16,6 +18,7 @@ def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_):
 | 
			
		||||
        username=identifier,
 | 
			
		||||
        name=f"SCIM Source {instance.name} Service-Account",
 | 
			
		||||
        type=UserTypes.INTERNAL_SERVICE_ACCOUNT,
 | 
			
		||||
        path=USER_PATH_SOURCE_SCIM,
 | 
			
		||||
    )
 | 
			
		||||
    token = Token.objects.create(
 | 
			
		||||
        user=user,
 | 
			
		||||
 | 
			
		||||
@ -13,6 +13,7 @@ from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import Group, User
 | 
			
		||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
 | 
			
		||||
from authentik.providers.scim.clients.schema import Group as SCIMGroupModel
 | 
			
		||||
from authentik.sources.scim.models import SCIMSourceGroup
 | 
			
		||||
from authentik.sources.scim.views.v2.base import SCIMView
 | 
			
		||||
@ -26,9 +27,11 @@ class GroupsView(SCIMView):
 | 
			
		||||
    def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict:
 | 
			
		||||
        """Convert Group to SCIM data"""
 | 
			
		||||
        payload = SCIMGroupModel(
 | 
			
		||||
            schemas=[SCIM_USER_SCHEMA],
 | 
			
		||||
            id=str(scim_group.group.pk),
 | 
			
		||||
            externalId=scim_group.id,
 | 
			
		||||
            displayName=scim_group.group.name,
 | 
			
		||||
            members=[],
 | 
			
		||||
            meta={
 | 
			
		||||
                "resourceType": "Group",
 | 
			
		||||
                "location": self.request.build_absolute_uri(
 | 
			
		||||
@ -42,28 +45,24 @@ class GroupsView(SCIMView):
 | 
			
		||||
                ),
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        return payload.model_dump(
 | 
			
		||||
            mode="json",
 | 
			
		||||
            exclude_unset=True,
 | 
			
		||||
        )
 | 
			
		||||
        for member in scim_group.group.users.order_by("pk"):
 | 
			
		||||
            member: User
 | 
			
		||||
            payload.members.append(GroupMember(value=str(member.uuid)))
 | 
			
		||||
        return payload.model_dump(mode="json", exclude_unset=True)
 | 
			
		||||
 | 
			
		||||
    def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response:
 | 
			
		||||
        """List Group handler"""
 | 
			
		||||
        base_query = SCIMSourceGroup.objects.select_related("group").prefetch_related(
 | 
			
		||||
            "group__users"
 | 
			
		||||
        )
 | 
			
		||||
        if group_id:
 | 
			
		||||
            connection = (
 | 
			
		||||
                SCIMSourceGroup.objects.filter(source=self.source, group__group_uuid=group_id)
 | 
			
		||||
                .select_related("group")
 | 
			
		||||
                .first()
 | 
			
		||||
            )
 | 
			
		||||
            connection = base_query.filter(source=self.source, group__group_uuid=group_id).first()
 | 
			
		||||
            if not connection:
 | 
			
		||||
                raise Http404
 | 
			
		||||
            return Response(self.group_to_scim(connection))
 | 
			
		||||
        connections = (
 | 
			
		||||
            SCIMSourceGroup.objects.filter(source=self.source)
 | 
			
		||||
            .select_related("group")
 | 
			
		||||
            .order_by("pk")
 | 
			
		||||
            base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request))
 | 
			
		||||
        )
 | 
			
		||||
        connections = connections.filter(self.filter_parse(request))
 | 
			
		||||
        page = self.paginate_query(connections)
 | 
			
		||||
        return Response(
 | 
			
		||||
            {
 | 
			
		||||
@ -79,6 +78,8 @@ class GroupsView(SCIMView):
 | 
			
		||||
    def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict):
 | 
			
		||||
        """Partial update a group"""
 | 
			
		||||
        group = connection.group if connection else Group()
 | 
			
		||||
        if _group := Group.objects.filter(name=data.get("displayName")).first():
 | 
			
		||||
            group = _group
 | 
			
		||||
        if "displayName" in data:
 | 
			
		||||
            group.name = data.get("displayName")
 | 
			
		||||
        if group.name == "":
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,7 @@ from rest_framework.request import Request
 | 
			
		||||
from rest_framework.response import Response
 | 
			
		||||
 | 
			
		||||
from authentik.core.models import User
 | 
			
		||||
from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA
 | 
			
		||||
from authentik.providers.scim.clients.schema import User as SCIMUserModel
 | 
			
		||||
from authentik.sources.scim.models import SCIMSourceUser
 | 
			
		||||
from authentik.sources.scim.views.v2.base import SCIMView
 | 
			
		||||
@ -33,6 +34,7 @@ class UsersView(SCIMView):
 | 
			
		||||
    def user_to_scim(self, scim_user: SCIMSourceUser) -> dict:
 | 
			
		||||
        """Convert User to SCIM data"""
 | 
			
		||||
        payload = SCIMUserModel(
 | 
			
		||||
            schemas=[SCIM_USER_SCHEMA],
 | 
			
		||||
            id=str(scim_user.user.uuid),
 | 
			
		||||
            externalId=scim_user.id,
 | 
			
		||||
            userName=scim_user.user.username,
 | 
			
		||||
@ -62,10 +64,7 @@ class UsersView(SCIMView):
 | 
			
		||||
                ),
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        final_payload = payload.model_dump(
 | 
			
		||||
            mode="json",
 | 
			
		||||
            exclude_unset=True,
 | 
			
		||||
        )
 | 
			
		||||
        final_payload = payload.model_dump(mode="json", exclude_unset=True)
 | 
			
		||||
        final_payload.update(scim_user.attributes)
 | 
			
		||||
        return final_payload
 | 
			
		||||
 | 
			
		||||
@ -99,6 +98,8 @@ class UsersView(SCIMView):
 | 
			
		||||
    def update_user(self, connection: SCIMSourceUser | None, data: QueryDict):
 | 
			
		||||
        """Partial update a user"""
 | 
			
		||||
        user = connection.user if connection else User()
 | 
			
		||||
        if _user := User.objects.filter(username=data.get("userName")).first():
 | 
			
		||||
            user = _user
 | 
			
		||||
        user.path = self.source.get_user_path()
 | 
			
		||||
        if "userName" in data:
 | 
			
		||||
            user.username = data.get("userName")
 | 
			
		||||
 | 
			
		||||
@ -96,7 +96,7 @@ class DuoDevice(SerializerModel, Device):
 | 
			
		||||
        return DuoDeviceSerializer
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return str(self.name) or str(self.user)
 | 
			
		||||
        return str(self.name) or str(self.user_id)
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("Duo Device")
 | 
			
		||||
 | 
			
		||||
@ -221,7 +221,7 @@ class SMSDevice(SerializerModel, SideChannelDevice):
 | 
			
		||||
        return valid
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return str(self.name) or str(self.user)
 | 
			
		||||
        return str(self.name) or str(self.user_id)
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("SMS Device")
 | 
			
		||||
 | 
			
		||||
@ -155,7 +155,7 @@ class WebAuthnDevice(SerializerModel, Device):
 | 
			
		||||
        return WebAuthnDeviceSerializer
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return str(self.name) or str(self.user)
 | 
			
		||||
        return str(self.name) or str(self.user_id)
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("WebAuthn Device")
 | 
			
		||||
 | 
			
		||||
@ -65,7 +65,7 @@ class UserConsent(SerializerModel, ExpiringModel):
 | 
			
		||||
        return UserConsentSerializer
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"User Consent {self.application} by {self.user}"
 | 
			
		||||
        return f"User Consent {self.application_id} by {self.user_id}"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        unique_together = (("user", "application", "permissions"),)
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,7 @@ class Invitation(SerializerModel, ExpiringModel):
 | 
			
		||||
        return InvitationSerializer
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return f"Invitation {str(self.invite_uuid)} created by {self.created_by}"
 | 
			
		||||
        return f"Invitation {str(self.invite_uuid)} created by {self.created_by_id}"
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        verbose_name = _("Invitation")
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,23 @@
 | 
			
		||||
# Generated by Django 5.0.4 on 2024-05-01 15:32
 | 
			
		||||
 | 
			
		||||
import authentik.lib.utils.time
 | 
			
		||||
from django.db import migrations, models
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Migration(migrations.Migration):
 | 
			
		||||
 | 
			
		||||
    dependencies = [
 | 
			
		||||
        ("authentik_tenants", "0002_tenant_default_token_duration_and_more"),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    operations = [
 | 
			
		||||
        migrations.AlterField(
 | 
			
		||||
            model_name="tenant",
 | 
			
		||||
            name="default_token_duration",
 | 
			
		||||
            field=models.TextField(
 | 
			
		||||
                default="days=1",
 | 
			
		||||
                help_text="Default token duration",
 | 
			
		||||
                validators=[authentik.lib.utils.time.timedelta_string_validator],
 | 
			
		||||
            ),
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -23,7 +23,7 @@ LOGGER = get_logger()
 | 
			
		||||
 | 
			
		||||
VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$")
 | 
			
		||||
 | 
			
		||||
DEFAULT_TOKEN_DURATION = "minutes=30"  # nosec
 | 
			
		||||
DEFAULT_TOKEN_DURATION = "days=1"  # nosec
 | 
			
		||||
DEFAULT_TOKEN_LENGTH = 60
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@
 | 
			
		||||
from tenant_schemas_celery.scheduler import (
 | 
			
		||||
    TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler,
 | 
			
		||||
)
 | 
			
		||||
from tenant_schemas_celery.scheduler import TenantAwareScheduleEntry
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
 | 
			
		||||
@ -11,3 +12,11 @@ class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler):
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_queryset(cls):
 | 
			
		||||
        return super().get_queryset().filter(ready=True)
 | 
			
		||||
 | 
			
		||||
    def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None):
 | 
			
		||||
        # https://github.com/maciej-gol/tenant-schemas-celery/blob/master/tenant_schemas_celery/scheduler.py#L85
 | 
			
		||||
        # When (as by default) no tenant schemas are set, the public schema is excluded
 | 
			
		||||
        # so we need to explicitly include it here, otherwise the task is not executed
 | 
			
		||||
        if entry.tenant_schemas is None:
 | 
			
		||||
            entry.tenant_schemas = self.get_queryset().values_list("schema_name", flat=True)
 | 
			
		||||
        return super().apply_entry(entry, producer)
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@
 | 
			
		||||
    "$schema": "http://json-schema.org/draft-07/schema",
 | 
			
		||||
    "$id": "https://goauthentik.io/blueprints/schema.json",
 | 
			
		||||
    "type": "object",
 | 
			
		||||
    "title": "authentik 2024.2.3 Blueprint schema",
 | 
			
		||||
    "title": "authentik 2024.4.4 Blueprint schema",
 | 
			
		||||
    "required": [
 | 
			
		||||
        "version",
 | 
			
		||||
        "entries"
 | 
			
		||||
@ -4131,6 +4131,10 @@
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512",
 | 
			
		||||
                        "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
 | 
			
		||||
                    ],
 | 
			
		||||
                    "title": "Signature algorithm"
 | 
			
		||||
@ -4935,6 +4939,10 @@
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384",
 | 
			
		||||
                        "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512",
 | 
			
		||||
                        "http://www.w3.org/2000/09/xmldsig#dsa-sha1"
 | 
			
		||||
                    ],
 | 
			
		||||
                    "title": "Signature algorithm"
 | 
			
		||||
 | 
			
		||||
@ -32,7 +32,7 @@ services:
 | 
			
		||||
    volumes:
 | 
			
		||||
      - redis:/data
 | 
			
		||||
  server:
 | 
			
		||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.3}
 | 
			
		||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.4}
 | 
			
		||||
    restart: unless-stopped
 | 
			
		||||
    command: server
 | 
			
		||||
    environment:
 | 
			
		||||
@ -53,7 +53,7 @@ services:
 | 
			
		||||
      - postgresql
 | 
			
		||||
      - redis
 | 
			
		||||
  worker:
 | 
			
		||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.3}
 | 
			
		||||
    image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.4}
 | 
			
		||||
    restart: unless-stopped
 | 
			
		||||
    command: worker
 | 
			
		||||
    environment:
 | 
			
		||||
 | 
			
		||||
@ -29,4 +29,4 @@ func UserAgent() string {
 | 
			
		||||
	return fmt.Sprintf("authentik@%s", FullVersion())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const VERSION = "2024.2.3"
 | 
			
		||||
const VERSION = "2024.4.4"
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,7 @@ function cleanup {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
function prepare_debug {
 | 
			
		||||
    poetry install --no-ansi --no-interaction
 | 
			
		||||
    VIRTUAL_ENV=/ak-root/venv poetry install --no-ansi --no-interaction
 | 
			
		||||
    touch /unittest.xml
 | 
			
		||||
    chown authentik:authentik /unittest.xml
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -117,6 +117,8 @@ def run_migrations():
 | 
			
		||||
        )
 | 
			
		||||
    finally:
 | 
			
		||||
        release_lock(curr)
 | 
			
		||||
        curr.close()
 | 
			
		||||
        conn.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,6 @@
 | 
			
		||||
import authentik. This is done by the dockerfile."""
 | 
			
		||||
from sys import exit as sysexit
 | 
			
		||||
from time import sleep
 | 
			
		||||
from urllib.parse import quote_plus
 | 
			
		||||
 | 
			
		||||
from psycopg import OperationalError, connect
 | 
			
		||||
from redis import Redis
 | 
			
		||||
@ -35,7 +34,7 @@ def check_postgres():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_redis():
 | 
			
		||||
    url = redis_url(CONFIG.get("redis.db"))
 | 
			
		||||
    url = CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db"))
 | 
			
		||||
    while True:
 | 
			
		||||
        try:
 | 
			
		||||
            redis = Redis.from_url(url)
 | 
			
		||||
@ -43,10 +42,7 @@ def check_redis():
 | 
			
		||||
            break
 | 
			
		||||
        except RedisError as exc:
 | 
			
		||||
            sleep(1)
 | 
			
		||||
            sanitized_url = url.replace(quote_plus(CONFIG.get("redis.password")), "******")
 | 
			
		||||
            CONFIG.log(
 | 
			
		||||
                "info", f"Redis Connection failed, retrying... ({exc})", redis_url=sanitized_url
 | 
			
		||||
            )
 | 
			
		||||
            CONFIG.log("info", f"Redis Connection failed, retrying... ({exc})")
 | 
			
		||||
    CONFIG.log("info", "Redis Connection successful")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1914
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1914
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -1,6 +1,6 @@
 | 
			
		||||
[tool.poetry]
 | 
			
		||||
name = "authentik"
 | 
			
		||||
version = "2024.2.3"
 | 
			
		||||
version = "2024.4.4"
 | 
			
		||||
description = ""
 | 
			
		||||
authors = ["authentik Team <hello@goauthentik.io>"]
 | 
			
		||||
 | 
			
		||||
@ -89,6 +89,7 @@ channels = { version = "*", extras = ["daphne"] }
 | 
			
		||||
channels-redis = "*"
 | 
			
		||||
codespell = "*"
 | 
			
		||||
colorama = "*"
 | 
			
		||||
cryptography = "*"
 | 
			
		||||
dacite = "*"
 | 
			
		||||
deepmerge = "*"
 | 
			
		||||
defusedxml = "*"
 | 
			
		||||
@ -101,7 +102,7 @@ django-redis = "*"
 | 
			
		||||
django-storages = { extras = ["s3"], version = "*" }
 | 
			
		||||
# See https://github.com/django-tenants/django-tenants/pull/997
 | 
			
		||||
django-tenants = { git = "https://github.com/rissson/django-tenants.git", branch="authentik-fixes" }
 | 
			
		||||
djangorestframework = "*"
 | 
			
		||||
djangorestframework = "3.14.0"
 | 
			
		||||
djangorestframework-guardian = "*"
 | 
			
		||||
docker = "*"
 | 
			
		||||
drf-spectacular = "*"
 | 
			
		||||
@ -115,17 +116,11 @@ gunicorn = "*"
 | 
			
		||||
jsonpatch = "*"
 | 
			
		||||
kubernetes = "*"
 | 
			
		||||
ldap3 = "*"
 | 
			
		||||
lxml = [
 | 
			
		||||
    # 5.0.0 works with libxml2 2.11.x, which is standard on brew
 | 
			
		||||
    { version = "5.0.0", platform = "darwin" },
 | 
			
		||||
    # 4.9.x works with previous libxml2 versions, which is what we get on linux
 | 
			
		||||
    { version = "4.9.4", platform = "linux" },
 | 
			
		||||
]
 | 
			
		||||
lxml = "*"
 | 
			
		||||
opencontainers = { extras = ["reggie"], version = "*" }
 | 
			
		||||
packaging = "*"
 | 
			
		||||
paramiko = "*"
 | 
			
		||||
psycopg = { extras = ["c"], version = "*" }
 | 
			
		||||
pycryptodome = "*"
 | 
			
		||||
pydantic = "*"
 | 
			
		||||
pydantic-scim = "*"
 | 
			
		||||
pyjwt = "*"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										23
									
								
								schema.yml
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								schema.yml
									
									
									
									
									
								
							@ -1,7 +1,7 @@
 | 
			
		||||
openapi: 3.0.3
 | 
			
		||||
info:
 | 
			
		||||
  title: authentik
 | 
			
		||||
  version: 2024.2.3
 | 
			
		||||
  version: 2024.4.4
 | 
			
		||||
  description: Making authentication simple.
 | 
			
		||||
  contact:
 | 
			
		||||
    email: hello@goauthentik.io
 | 
			
		||||
@ -17051,6 +17051,10 @@ paths:
 | 
			
		||||
          enum:
 | 
			
		||||
          - http://www.w3.org/2000/09/xmldsig#dsa-sha1
 | 
			
		||||
          - http://www.w3.org/2000/09/xmldsig#rsa-sha1
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512
 | 
			
		||||
@ -20910,6 +20914,10 @@ paths:
 | 
			
		||||
          enum:
 | 
			
		||||
          - http://www.w3.org/2000/09/xmldsig#dsa-sha1
 | 
			
		||||
          - http://www.w3.org/2000/09/xmldsig#rsa-sha1
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384
 | 
			
		||||
          - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512
 | 
			
		||||
@ -30450,6 +30458,11 @@ components:
 | 
			
		||||
      - pending_user
 | 
			
		||||
      - pending_user_avatar
 | 
			
		||||
      - type
 | 
			
		||||
    AlgEnum:
 | 
			
		||||
      enum:
 | 
			
		||||
      - rsa
 | 
			
		||||
      - ecdsa
 | 
			
		||||
      type: string
 | 
			
		||||
    App:
 | 
			
		||||
      type: object
 | 
			
		||||
      description: Serialize Application info
 | 
			
		||||
@ -32107,6 +32120,10 @@ components:
 | 
			
		||||
          type: string
 | 
			
		||||
        validity_days:
 | 
			
		||||
          type: integer
 | 
			
		||||
        alg:
 | 
			
		||||
          allOf:
 | 
			
		||||
          - $ref: '#/components/schemas/AlgEnum'
 | 
			
		||||
          default: rsa
 | 
			
		||||
      required:
 | 
			
		||||
      - common_name
 | 
			
		||||
      - validity_days
 | 
			
		||||
@ -43658,6 +43675,10 @@ components:
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384
 | 
			
		||||
      - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512
 | 
			
		||||
      - http://www.w3.org/2000/09/xmldsig#dsa-sha1
 | 
			
		||||
      type: string
 | 
			
		||||
    Source:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1934
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1934
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -38,7 +38,7 @@
 | 
			
		||||
        "@codemirror/theme-one-dark": "^6.1.2",
 | 
			
		||||
        "@formatjs/intl-listformat": "^7.5.5",
 | 
			
		||||
        "@fortawesome/fontawesome-free": "^6.5.2",
 | 
			
		||||
        "@goauthentik/api": "^2024.2.3-1713441634",
 | 
			
		||||
        "@goauthentik/api": "^2024.4.1-1714655911",
 | 
			
		||||
        "@lit-labs/task": "^3.1.0",
 | 
			
		||||
        "@lit/context": "^1.1.1",
 | 
			
		||||
        "@lit/localize": "^0.12.1",
 | 
			
		||||
 | 
			
		||||
@ -29,5 +29,9 @@ export const signatureAlgorithmOptions = toOptions([
 | 
			
		||||
    ["RSA-SHA256", SignatureAlgorithmEnum._200104XmldsigMorersaSha256, true],
 | 
			
		||||
    ["RSA-SHA384", SignatureAlgorithmEnum._200104XmldsigMorersaSha384],
 | 
			
		||||
    ["RSA-SHA512", SignatureAlgorithmEnum._200104XmldsigMorersaSha512],
 | 
			
		||||
    ["ECDSA-SHA1", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha1],
 | 
			
		||||
    ["ECDSA-SHA256", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha256],
 | 
			
		||||
    ["ECDSA-SHA384", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha384],
 | 
			
		||||
    ["ECDSA-SHA512", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha512],
 | 
			
		||||
    ["DSA-SHA1", SignatureAlgorithmEnum._200009XmldsigdsaSha1],
 | 
			
		||||
]);
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,12 @@ import { msg } from "@lit/localize";
 | 
			
		||||
import { TemplateResult, html } from "lit";
 | 
			
		||||
import { customElement } from "lit/decorators.js";
 | 
			
		||||
 | 
			
		||||
import { CertificateGenerationRequest, CertificateKeyPair, CryptoApi } from "@goauthentik/api";
 | 
			
		||||
import {
 | 
			
		||||
    AlgEnum,
 | 
			
		||||
    CertificateGenerationRequest,
 | 
			
		||||
    CertificateKeyPair,
 | 
			
		||||
    CryptoApi,
 | 
			
		||||
} from "@goauthentik/api";
 | 
			
		||||
 | 
			
		||||
@customElement("ak-crypto-certificate-generate-form")
 | 
			
		||||
export class CertificateKeyPairForm extends Form<CertificateGenerationRequest> {
 | 
			
		||||
@ -40,6 +45,29 @@ export class CertificateKeyPairForm extends Form<CertificateGenerationRequest> {
 | 
			
		||||
                ?required=${true}
 | 
			
		||||
            >
 | 
			
		||||
                <input class="pf-c-form-control" type="number" value="365" />
 | 
			
		||||
            </ak-form-element-horizontal>`;
 | 
			
		||||
            </ak-form-element-horizontal>
 | 
			
		||||
            <ak-form-element-horizontal
 | 
			
		||||
                label=${msg("Private key Algorithm")}
 | 
			
		||||
                ?required=${true}
 | 
			
		||||
                name="alg"
 | 
			
		||||
            >
 | 
			
		||||
                <ak-radio
 | 
			
		||||
                    .options=${[
 | 
			
		||||
                        {
 | 
			
		||||
                            label: msg("RSA"),
 | 
			
		||||
                            value: AlgEnum.Rsa,
 | 
			
		||||
                            default: true,
 | 
			
		||||
                        },
 | 
			
		||||
                        {
 | 
			
		||||
                            label: msg("ECDSA"),
 | 
			
		||||
                            value: AlgEnum.Ecdsa,
 | 
			
		||||
                        },
 | 
			
		||||
                    ]}
 | 
			
		||||
                >
 | 
			
		||||
                </ak-radio>
 | 
			
		||||
                <p class="pf-c-form__helper-text">
 | 
			
		||||
                    ${msg("Algorithm used to generate the private key.")}
 | 
			
		||||
                </p>
 | 
			
		||||
            </ak-form-element-horizontal> `;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -97,7 +97,7 @@ export class EventListPage extends TablePage<Event> {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    renderExpanded(item: Event): TemplateResult {
 | 
			
		||||
        return html` <td role="cell" colspan="3">
 | 
			
		||||
        return html` <td role="cell" colspan="5">
 | 
			
		||||
                <div class="pf-c-table__expandable-row-content">
 | 
			
		||||
                    <ak-event-info .event=${item as EventWithContext}></ak-event-info>
 | 
			
		||||
                </div>
 | 
			
		||||
 | 
			
		||||
@ -214,28 +214,23 @@ export class IdentificationStageForm extends BaseStageForm<IdentificationStage>
 | 
			
		||||
                        name="sources"
 | 
			
		||||
                    >
 | 
			
		||||
                        <select class="pf-c-form-control" multiple>
 | 
			
		||||
                            ${this.sources?.results.map((source) => {
 | 
			
		||||
                                let selected = Array.from(this.instance?.sources || []).some(
 | 
			
		||||
                                    (su) => {
 | 
			
		||||
                                        return su == source.pk;
 | 
			
		||||
                                    },
 | 
			
		||||
                                );
 | 
			
		||||
                                // Creating a new instance, auto-select built-in source
 | 
			
		||||
                                // Only when no other sources exist
 | 
			
		||||
                                if (
 | 
			
		||||
                                    !this.instance &&
 | 
			
		||||
                                    source.component === "" &&
 | 
			
		||||
                                    (this.sources?.results || []).length < 2
 | 
			
		||||
                                ) {
 | 
			
		||||
                                    selected = true;
 | 
			
		||||
                                }
 | 
			
		||||
                                return html`<option
 | 
			
		||||
                                    value=${ifDefined(source.pk)}
 | 
			
		||||
                                    ?selected=${selected}
 | 
			
		||||
                                >
 | 
			
		||||
                                    ${source.name}
 | 
			
		||||
                                </option>`;
 | 
			
		||||
                            })}
 | 
			
		||||
                            ${this.sources?.results
 | 
			
		||||
                                .filter((source) => {
 | 
			
		||||
                                    return source.component !== "";
 | 
			
		||||
                                })
 | 
			
		||||
                                .map((source) => {
 | 
			
		||||
                                    const selected = Array.from(this.instance?.sources || []).some(
 | 
			
		||||
                                        (su) => {
 | 
			
		||||
                                            return su == source.pk;
 | 
			
		||||
                                        },
 | 
			
		||||
                                    );
 | 
			
		||||
                                    return html`<option
 | 
			
		||||
                                        value=${ifDefined(source.pk)}
 | 
			
		||||
                                        ?selected=${selected}
 | 
			
		||||
                                    >
 | 
			
		||||
                                        ${source.name}
 | 
			
		||||
                                    </option>`;
 | 
			
		||||
                                })}
 | 
			
		||||
                        </select>
 | 
			
		||||
                        <p class="pf-c-form__helper-text">
 | 
			
		||||
                            ${msg(
 | 
			
		||||
 | 
			
		||||
@ -128,6 +128,14 @@ export class UserForm extends ModelForm<User, number> {
 | 
			
		||||
                                "Service accounts should be used for machine-to-machine authentication or other automations.",
 | 
			
		||||
                            )}`,
 | 
			
		||||
                        },
 | 
			
		||||
                        {
 | 
			
		||||
                            label: "Internal Service account",
 | 
			
		||||
                            value: UserTypeEnum.InternalServiceAccount,
 | 
			
		||||
                            disabled: true,
 | 
			
		||||
                            description: html`${msg(
 | 
			
		||||
                                "Internal Service accounts are created and managed by authentik and cannot be created manually.",
 | 
			
		||||
                            )}`,
 | 
			
		||||
                        },
 | 
			
		||||
                    ]}
 | 
			
		||||
                    .value=${this.instance?.type}
 | 
			
		||||
                >
 | 
			
		||||
 | 
			
		||||
@ -3,7 +3,7 @@ export const SUCCESS_CLASS = "pf-m-success";
 | 
			
		||||
export const ERROR_CLASS = "pf-m-danger";
 | 
			
		||||
export const PROGRESS_CLASS = "pf-m-in-progress";
 | 
			
		||||
export const CURRENT_CLASS = "pf-m-current";
 | 
			
		||||
export const VERSION = "2024.2.3";
 | 
			
		||||
export const VERSION = "2024.4.4";
 | 
			
		||||
export const TITLE_DEFAULT = "authentik";
 | 
			
		||||
export const ROUTE_SEPARATOR = ";";
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -187,6 +187,9 @@ input[type="date"]::-webkit-calendar-picker-indicator {
 | 
			
		||||
.pf-c-select__menu-item.pf-m-focus {
 | 
			
		||||
    --pf-c-select__menu-item--focus--BackgroundColor: var(--ak-dark-background-light-ish);
 | 
			
		||||
}
 | 
			
		||||
.pf-c-button:disabled {
 | 
			
		||||
    color: var(--ak-dark-background-lighter);
 | 
			
		||||
}
 | 
			
		||||
.pf-c-button.pf-m-plain:hover {
 | 
			
		||||
    color: var(--ak-dark-foreground);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ export function me(): Promise<SessionUser> {
 | 
			
		||||
                if (!user.user.settings || !("locale" in user.user.settings)) {
 | 
			
		||||
                    return user;
 | 
			
		||||
                }
 | 
			
		||||
                const locale = user.user.settings.locale;
 | 
			
		||||
                const locale: string | undefined = user.user.settings.locale;
 | 
			
		||||
                if (locale && locale !== "") {
 | 
			
		||||
                    console.debug(
 | 
			
		||||
                        `authentik/locale: Activating user's configured locale '${locale}'`,
 | 
			
		||||
 | 
			
		||||
@ -111,6 +111,21 @@ export function dateTimeLocal(date: Date): string {
 | 
			
		||||
    return `${parts[0]}:${parts[1]}`;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function dateToUTC(date: Date): Date {
 | 
			
		||||
    // Sigh...so our API is UTC/can take TZ info in the ISO format as it should.
 | 
			
		||||
    // datetime-local fields (which is almost the only date-time input we use)
 | 
			
		||||
    // can return its value as a UTC timestamp...however the generated API client
 | 
			
		||||
    // _requires_ a Date object, only to then convert it to an ISO string anyways
 | 
			
		||||
    // JS Dates don't include timezone info in the ISO string, so that just sends
 | 
			
		||||
    // the local time as UTC...which is wrong
 | 
			
		||||
    // Instead we have to do this, convert the given date to a UTC timestamp,
 | 
			
		||||
    // then subtract the timezone offset to create an "invalid" date (correct time&date)
 | 
			
		||||
    // but it still "thinks" it's in local TZ
 | 
			
		||||
    const timestamp = date.getTime();
 | 
			
		||||
    const offset = -1 * (new Date().getTimezoneOffset() * 60000);
 | 
			
		||||
    return new Date(timestamp - offset);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Lit is extremely well-typed with regard to CSS, and Storybook's `build` does not currently have a
 | 
			
		||||
// coherent way of importing CSS-as-text into CSSStyleSheet. It works well when Storybook is running
 | 
			
		||||
// in `dev,` but in `build` it fails. Storied components will have to map their textual CSS imports
 | 
			
		||||
 | 
			
		||||
@ -18,6 +18,7 @@ import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList
 | 
			
		||||
import PFList from "@patternfly/patternfly/components/List/list.css";
 | 
			
		||||
import PFTable from "@patternfly/patternfly/components/Table/table.css";
 | 
			
		||||
import PFFlex from "@patternfly/patternfly/layouts/Flex/flex.css";
 | 
			
		||||
import PFSplit from "@patternfly/patternfly/layouts/Split/split.css";
 | 
			
		||||
import PFBase from "@patternfly/patternfly/patternfly-base.css";
 | 
			
		||||
 | 
			
		||||
import { EventActions, FlowsApi } from "@goauthentik/api";
 | 
			
		||||
@ -81,6 +82,7 @@ export class EventInfo extends AKElement {
 | 
			
		||||
            PFCard,
 | 
			
		||||
            PFTable,
 | 
			
		||||
            PFList,
 | 
			
		||||
            PFSplit,
 | 
			
		||||
            PFDescriptionList,
 | 
			
		||||
            css`
 | 
			
		||||
                code {
 | 
			
		||||
@ -246,11 +248,17 @@ export class EventInfo extends AKElement {
 | 
			
		||||
 | 
			
		||||
    renderModelChanged() {
 | 
			
		||||
        const diff = this.event.context.diff as unknown as {
 | 
			
		||||
            [key: string]: { new_value: unknown; previous_value: unknown };
 | 
			
		||||
            [key: string]: {
 | 
			
		||||
                new_value: unknown;
 | 
			
		||||
                previous_value: unknown;
 | 
			
		||||
                add?: unknown[];
 | 
			
		||||
                remove?: unknown[];
 | 
			
		||||
                clear?: boolean;
 | 
			
		||||
            };
 | 
			
		||||
        };
 | 
			
		||||
        let diffBody = html``;
 | 
			
		||||
        if (diff) {
 | 
			
		||||
            diffBody = html`<div class="pf-l-flex__item">
 | 
			
		||||
            diffBody = html`<div class="pf-l-split__item pf-m-fill">
 | 
			
		||||
                    <div class="pf-c-card__title">${msg("Changes made:")}</div>
 | 
			
		||||
                    <table class="pf-c-table pf-m-compact pf-m-grid-md" role="grid">
 | 
			
		||||
                        <thead>
 | 
			
		||||
@ -262,16 +270,36 @@ export class EventInfo extends AKElement {
 | 
			
		||||
                        </thead>
 | 
			
		||||
                        <tbody role="rowgroup">
 | 
			
		||||
                            ${Object.keys(diff).map((key) => {
 | 
			
		||||
                                const value = diff[key];
 | 
			
		||||
                                const previousCol = value.previous_value
 | 
			
		||||
                                    ? JSON.stringify(value.previous_value, null, 4)
 | 
			
		||||
                                    : msg("-");
 | 
			
		||||
                                let newCol = html``;
 | 
			
		||||
                                if (value.add || value.remove) {
 | 
			
		||||
                                    newCol = html`<ul class="pf-c-list">
 | 
			
		||||
                                        ${(value.add || value.remove)?.map((item) => {
 | 
			
		||||
                                            let itemLabel = "";
 | 
			
		||||
                                            if (value.add) {
 | 
			
		||||
                                                itemLabel = msg(str`Added ID ${item}`);
 | 
			
		||||
                                            } else if (value.remove) {
 | 
			
		||||
                                                itemLabel = msg(str`Removed ID ${item}`);
 | 
			
		||||
                                            }
 | 
			
		||||
                                            return html`<li>${itemLabel}</li>`;
 | 
			
		||||
                                        })}
 | 
			
		||||
                                    </ul>`;
 | 
			
		||||
                                } else if (value.clear) {
 | 
			
		||||
                                    newCol = html`${msg("Cleared")}`;
 | 
			
		||||
                                } else {
 | 
			
		||||
                                    newCol = html`<pre>
 | 
			
		||||
${JSON.stringify(value.new_value, null, 4)}</pre
 | 
			
		||||
                                    >`;
 | 
			
		||||
                                }
 | 
			
		||||
                                return html` <tr role="row">
 | 
			
		||||
                                    <td role="cell"><pre>${key}</pre></td>
 | 
			
		||||
                                    <td role="cell">
 | 
			
		||||
                                        <pre>
 | 
			
		||||
${JSON.stringify(diff[key].previous_value, null, 4)}</pre
 | 
			
		||||
                                        >
 | 
			
		||||
                                    </td>
 | 
			
		||||
                                    <td role="cell">
 | 
			
		||||
                                        <pre>${JSON.stringify(diff[key].new_value, null, 4)}</pre>
 | 
			
		||||
                                        <pre>${previousCol}</pre>
 | 
			
		||||
                                    </td>
 | 
			
		||||
                                    <td role="cell">${newCol}</td>
 | 
			
		||||
                                </tr>`;
 | 
			
		||||
                            })}
 | 
			
		||||
                        </tbody>
 | 
			
		||||
@ -280,8 +308,8 @@ ${JSON.stringify(diff[key].previous_value, null, 4)}</pre
 | 
			
		||||
                </div>`;
 | 
			
		||||
        }
 | 
			
		||||
        return html`
 | 
			
		||||
            <div class="pf-l-flex">
 | 
			
		||||
                <div class="pf-l-flex__item">
 | 
			
		||||
            <div class="pf-l-split">
 | 
			
		||||
                <div class="pf-l-split__item pf-m-fill">
 | 
			
		||||
                    <div class="pf-c-card__title">${msg("Affected model:")}</div>
 | 
			
		||||
                    <div class="pf-c-card__body">
 | 
			
		||||
                        ${this.getModelInfo(this.event.context?.model as EventModel)}
 | 
			
		||||
 | 
			
		||||
@ -87,7 +87,7 @@ export class Markdown extends AKElement {
 | 
			
		||||
            const parsedContent = matter(this.md);
 | 
			
		||||
            const parsedHTML = this.converter.makeHtml(parsedContent.content);
 | 
			
		||||
            const replacers = [...this.defaultReplacers, ...this.replacers];
 | 
			
		||||
            this.docTitle = parsedContent.data["title"] ?? "";
 | 
			
		||||
            this.docTitle = parsedContent?.data?.title ?? "";
 | 
			
		||||
            this.docHtml = replacers.reduce(
 | 
			
		||||
                (html, replacer) => replacer(html, { path: this.meta }),
 | 
			
		||||
                parsedHTML,
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,7 @@ import { WithBrandConfig } from "@goauthentik/elements/Interface/brandProvider";
 | 
			
		||||
import "@patternfly/elements/pf-tooltip/pf-tooltip.js";
 | 
			
		||||
 | 
			
		||||
import { msg } from "@lit/localize";
 | 
			
		||||
import { CSSResult, PropertyValues, TemplateResult, css, html } from "lit";
 | 
			
		||||
import { CSSResult, TemplateResult, css, html } from "lit";
 | 
			
		||||
import { customElement, property } from "lit/decorators.js";
 | 
			
		||||
 | 
			
		||||
import PFButton from "@patternfly/patternfly/components/Button/button.css";
 | 
			
		||||
@ -107,21 +107,23 @@ export class PageHeader extends WithBrandConfig(AKElement) {
 | 
			
		||||
        });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    setTitle(value: string) {
 | 
			
		||||
    setTitle(header?: string) {
 | 
			
		||||
        const currentIf = currentInterface();
 | 
			
		||||
        const title = this.brand?.brandingTitle || TITLE_DEFAULT;
 | 
			
		||||
        document.title =
 | 
			
		||||
            currentIf === "admin"
 | 
			
		||||
                ? `${msg("Admin")} - ${title}`
 | 
			
		||||
                : value !== ""
 | 
			
		||||
                  ? `${value} - ${title}`
 | 
			
		||||
                  : title;
 | 
			
		||||
        let title = this.brand?.brandingTitle || TITLE_DEFAULT;
 | 
			
		||||
        if (currentIf === "admin") {
 | 
			
		||||
            title = `${msg("Admin")} - ${title}`;
 | 
			
		||||
        }
 | 
			
		||||
        // Prepend the header to the title
 | 
			
		||||
        if (header !== undefined && header !== "") {
 | 
			
		||||
            title = `${header} - ${title}`;
 | 
			
		||||
        }
 | 
			
		||||
        document.title = title;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    willUpdate(changedProperties: PropertyValues<this>) {
 | 
			
		||||
        if (changedProperties.has("header") && this.header) {
 | 
			
		||||
            this.setTitle(this.header);
 | 
			
		||||
        }
 | 
			
		||||
    willUpdate() {
 | 
			
		||||
        // Always update title, even if there's no header value set,
 | 
			
		||||
        // as in that case we still need to return to the generic title
 | 
			
		||||
        this.setTitle(this.header);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    renderIcon(): TemplateResult {
 | 
			
		||||
 | 
			
		||||
@ -2,8 +2,9 @@ import { AKElement } from "@goauthentik/elements/Base";
 | 
			
		||||
import { CustomEmitterElement } from "@goauthentik/elements/utils/eventEmitter";
 | 
			
		||||
 | 
			
		||||
import { msg } from "@lit/localize";
 | 
			
		||||
import { PropertyValues } from "@lit/reactive-element/reactive-element";
 | 
			
		||||
import { TemplateResult, css, html } from "lit";
 | 
			
		||||
import { customElement, property, queryAll } from "lit/decorators.js";
 | 
			
		||||
import { customElement, property, queryAll, state } from "lit/decorators.js";
 | 
			
		||||
import { map } from "lit/directives/map.js";
 | 
			
		||||
 | 
			
		||||
import PFCheck from "@patternfly/patternfly/components/Check/check.css";
 | 
			
		||||
@ -112,10 +113,14 @@ export class CheckboxGroup extends AkElementWithCustomEvents {
 | 
			
		||||
    @queryAll('input[type="checkbox"]')
 | 
			
		||||
    checkboxes!: NodeListOf<HTMLInputElement>;
 | 
			
		||||
 | 
			
		||||
    internals?: ElementInternals;
 | 
			
		||||
    @state()
 | 
			
		||||
    values: string[] = [];
 | 
			
		||||
 | 
			
		||||
    get json() {
 | 
			
		||||
        return this.value;
 | 
			
		||||
    internals?: ElementInternals;
 | 
			
		||||
    doneFirstUpdate = false;
 | 
			
		||||
 | 
			
		||||
    json() {
 | 
			
		||||
        return this.values;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    private get formValue() {
 | 
			
		||||
@ -124,7 +129,7 @@ export class CheckboxGroup extends AkElementWithCustomEvents {
 | 
			
		||||
        }
 | 
			
		||||
        const name = this.name;
 | 
			
		||||
        const entries = new FormData();
 | 
			
		||||
        this.value.forEach((v) => entries.append(name, v));
 | 
			
		||||
        this.values.forEach((v) => entries.append(name, v));
 | 
			
		||||
        return entries;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -136,14 +141,14 @@ export class CheckboxGroup extends AkElementWithCustomEvents {
 | 
			
		||||
 | 
			
		||||
    onClick(ev: Event) {
 | 
			
		||||
        ev.stopPropagation();
 | 
			
		||||
        this.value = Array.from(this.checkboxes)
 | 
			
		||||
        this.values = Array.from(this.checkboxes)
 | 
			
		||||
            .filter((checkbox) => checkbox.checked)
 | 
			
		||||
            .map((checkbox) => checkbox.name);
 | 
			
		||||
        this.dispatchCustomEvent("change", this.value);
 | 
			
		||||
        this.dispatchCustomEvent("input", this.value);
 | 
			
		||||
        this.dispatchCustomEvent("change", this.values);
 | 
			
		||||
        this.dispatchCustomEvent("input", this.values);
 | 
			
		||||
        if (this.internals) {
 | 
			
		||||
            this.internals.setValidity({});
 | 
			
		||||
            if (this.required && this.value.length === 0) {
 | 
			
		||||
            if (this.required && this.values.length === 0) {
 | 
			
		||||
                this.internals.setValidity(
 | 
			
		||||
                    {
 | 
			
		||||
                        valueMissing: true,
 | 
			
		||||
@ -154,6 +159,16 @@ export class CheckboxGroup extends AkElementWithCustomEvents {
 | 
			
		||||
            }
 | 
			
		||||
            this.internals.setFormValue(this.formValue);
 | 
			
		||||
        }
 | 
			
		||||
        // Doing a write-back so anyone examining the checkbox.value field will get something
 | 
			
		||||
        // meaningful. Doesn't do anything for anyone, usually, but it's nice to have.
 | 
			
		||||
        this.value = this.values;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    willUpdate(changed: PropertyValues<this>) {
 | 
			
		||||
        if (changed.has("value") && !this.doneFirstUpdate) {
 | 
			
		||||
            this.doneFirstUpdate = true;
 | 
			
		||||
            this.values = this.value;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    connectedCallback() {
 | 
			
		||||
@ -183,7 +198,7 @@ export class CheckboxGroup extends AkElementWithCustomEvents {
 | 
			
		||||
 | 
			
		||||
    render() {
 | 
			
		||||
        const renderOne = ([name, label]: CheckboxPr) => {
 | 
			
		||||
            const selected = this.value.includes(name);
 | 
			
		||||
            const selected = this.values.includes(name);
 | 
			
		||||
            const blockFwd = (e: Event) => {
 | 
			
		||||
                e.stopImmediatePropagation();
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
@ -53,6 +53,9 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) {
 | 
			
		||||
 | 
			
		||||
    private isLoading = false;
 | 
			
		||||
 | 
			
		||||
    private doneFirstUpdate = false;
 | 
			
		||||
    private internalSelected: DualSelectPair[] = [];
 | 
			
		||||
 | 
			
		||||
    private pagination?: Pagination;
 | 
			
		||||
 | 
			
		||||
    constructor() {
 | 
			
		||||
@ -69,6 +72,11 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    willUpdate(changedProperties: PropertyValues<this>) {
 | 
			
		||||
        if (changedProperties.has("selected") && !this.doneFirstUpdate) {
 | 
			
		||||
            this.doneFirstUpdate = true;
 | 
			
		||||
            this.internalSelected = this.selected;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (changedProperties.has("searchDelay")) {
 | 
			
		||||
            this.doSearch = debounce(
 | 
			
		||||
                AkDualSelectProvider.prototype.doSearch.bind(this),
 | 
			
		||||
@ -105,7 +113,8 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) {
 | 
			
		||||
        if (!(event instanceof CustomEvent)) {
 | 
			
		||||
            throw new Error(`Expecting a CustomEvent for change, received ${event} instead`);
 | 
			
		||||
        }
 | 
			
		||||
        this.selected = event.detail.value;
 | 
			
		||||
        this.internalSelected = event.detail.value;
 | 
			
		||||
        this.selected = this.internalSelected;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    onSearch(event: Event) {
 | 
			
		||||
@ -124,12 +133,16 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) {
 | 
			
		||||
        return this.dualSelector.value!.selected.map(([k, _]) => k);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    json() {
 | 
			
		||||
        return this.value;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    render() {
 | 
			
		||||
        return html`<ak-dual-select
 | 
			
		||||
            ${ref(this.dualSelector)}
 | 
			
		||||
            .options=${this.options}
 | 
			
		||||
            .pages=${this.pagination}
 | 
			
		||||
            .selected=${this.selected}
 | 
			
		||||
            .selected=${this.internalSelected}
 | 
			
		||||
            available-label=${this.availableLabel}
 | 
			
		||||
            selected-label=${this.selectedLabel}
 | 
			
		||||
        ></ak-dual-select>`;
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,7 @@ import { EVENT_LOCALE_CHANGE, EVENT_LOCALE_REQUEST } from "@goauthentik/common/c
 | 
			
		||||
import { customEvent } from "@goauthentik/elements/utils/customEvents";
 | 
			
		||||
 | 
			
		||||
import { LitElement, html } from "lit";
 | 
			
		||||
import { customElement, property, state } from "lit/decorators.js";
 | 
			
		||||
import { customElement, property } from "lit/decorators.js";
 | 
			
		||||
 | 
			
		||||
import { WithBrandConfig } from "../Interface/brandProvider";
 | 
			
		||||
import { initializeLocalization } from "./configureLocale";
 | 
			
		||||
@ -38,9 +38,6 @@ export class LocaleContext extends LocaleContextBase {
 | 
			
		||||
 | 
			
		||||
    setLocale: LocaleSetter;
 | 
			
		||||
 | 
			
		||||
    @state()
 | 
			
		||||
    userLocale = "";
 | 
			
		||||
 | 
			
		||||
    constructor(code = DEFAULT_LOCALE) {
 | 
			
		||||
        super();
 | 
			
		||||
        this.notifyApplication = this.notifyApplication.bind(this);
 | 
			
		||||
@ -59,30 +56,22 @@ export class LocaleContext extends LocaleContextBase {
 | 
			
		||||
 | 
			
		||||
    connectedCallback() {
 | 
			
		||||
        super.connectedCallback();
 | 
			
		||||
        // Commenting out until we can come up with a better way of separating the
 | 
			
		||||
        // "request user identity" with the session expiration heartbeat.
 | 
			
		||||
        /*
 | 
			
		||||
            new CoreApi(DEFAULT_CONFIG)
 | 
			
		||||
                .coreUsersMeRetrieve()
 | 
			
		||||
                .then((user) => (this.userLocale = user?.user?.settings?.locale ?? ""))
 | 
			
		||||
                .catch(() => {});
 | 
			
		||||
        */
 | 
			
		||||
        this.updateLocale();
 | 
			
		||||
        window.addEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler);
 | 
			
		||||
        window.addEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler as EventListener);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    disconnectedCallback() {
 | 
			
		||||
        window.removeEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler);
 | 
			
		||||
        window.removeEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler as EventListener);
 | 
			
		||||
        super.disconnectedCallback();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    updateLocaleHandler(_ev: Event) {
 | 
			
		||||
    updateLocaleHandler(ev: CustomEvent<{ locale: string }>) {
 | 
			
		||||
        console.debug("authentik/locale: Locale update request received.");
 | 
			
		||||
        this.updateLocale();
 | 
			
		||||
        this.updateLocale(ev.detail.locale);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    updateLocale() {
 | 
			
		||||
        const localeRequest = autoDetectLanguage(this.userLocale, this.brand?.defaultLocale);
 | 
			
		||||
    updateLocale(requestedLocale: string | undefined = undefined) {
 | 
			
		||||
        const localeRequest = autoDetectLanguage(requestedLocale, this.brand?.defaultLocale);
 | 
			
		||||
        const locale = getBestMatchLocale(localeRequest);
 | 
			
		||||
        if (!locale) {
 | 
			
		||||
            console.warn(`authentik/locale: failed to find locale for code ${localeRequest}`);
 | 
			
		||||
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user