Compare commits
	
		
			50 Commits
		
	
	
		
			version/20
			...
			version-20
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 1a21479b0d | |||
| 38154f72e0 | |||
| 19318d4c00 | |||
| be3d7c0666 | |||
| 5afceaa55f | |||
| 72dc27f1c9 | |||
| b5ffd16861 | |||
| 8af754e88c | |||
| ade1f08c89 | |||
| 9240fa1037 | |||
| 1f5953b5b7 | |||
| 5befccc1fd | |||
| ff193d809a | |||
| 23bbb6e5ef | |||
| 225d02d02d | |||
| 90fe1eda66 | |||
| 35ba88a203 | |||
| 8414a9dcad | |||
| 1d626f5b57 | |||
| 508dd0ac64 | |||
| f4b82a8b09 | |||
| 2900f01976 | |||
| 0f6ece5eb7 | |||
| b9936fe532 | |||
| d0b3cc5916 | |||
| e034f5e5dc | |||
| 9d6816bbc8 | |||
| 82d4ea9e8a | |||
| c8a804f2a7 | |||
| ca70c963e5 | |||
| 4c89d4a4a4 | |||
| 8a47acac3a | |||
| 4a3b22491c | |||
| f991d656c7 | |||
| e86aa11131 | |||
| 03725ae086 | |||
| f2a37e8c7c | |||
| e935690b1b | |||
| 02709e4ede | |||
| f78adab9d1 | |||
| 61f3a72fd9 | |||
| 541becfe30 | |||
| 11ff7955f7 | |||
| afa4234036 | |||
| ca22a4deaf | |||
| 7b7a3d34ec | |||
| b1ca579397 | |||
| c8072579c8 | |||
| 378a701fb9 | |||
| bba793d94c | 
| @ -1,5 +1,5 @@ | |||||||
| [bumpversion] | [bumpversion] | ||||||
| current_version = 2024.2.3 | current_version = 2024.4.4 | ||||||
| tag = True | tag = True | ||||||
| commit = True | commit = True | ||||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<rc_t>[a-zA-Z-]+)(?P<rc_n>[1-9]\\d*))? | ||||||
| @ -21,6 +21,8 @@ optional_value = final | |||||||
|  |  | ||||||
| [bumpversion:file:schema.yml] | [bumpversion:file:schema.yml] | ||||||
|  |  | ||||||
|  | [bumpversion:file:blueprints/schema.json] | ||||||
|  |  | ||||||
| [bumpversion:file:authentik/__init__.py] | [bumpversion:file:authentik/__init__.py] | ||||||
|  |  | ||||||
| [bumpversion:file:internal/constants/constants.go] | [bumpversion:file:internal/constants/constants.go] | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ should_build = str(os.environ.get("DOCKER_USERNAME", None) is not None).lower() | |||||||
| branch_name = os.environ["GITHUB_REF"] | branch_name = os.environ["GITHUB_REF"] | ||||||
| if os.environ.get("GITHUB_HEAD_REF", "") != "": | if os.environ.get("GITHUB_HEAD_REF", "") != "": | ||||||
|     branch_name = os.environ["GITHUB_HEAD_REF"] |     branch_name = os.environ["GITHUB_HEAD_REF"] | ||||||
| safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-") | safe_branch_name = branch_name.replace("refs/heads/", "").replace("/", "-").replace("'", "-") | ||||||
|  |  | ||||||
| image_names = os.getenv("IMAGE_NAME").split(",") | image_names = os.getenv("IMAGE_NAME").split(",") | ||||||
| image_arch = os.getenv("IMAGE_ARCH") or None | image_arch = os.getenv("IMAGE_ARCH") or None | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-web.yml
									
									
									
									
										vendored
									
									
								
							| @ -34,6 +34,13 @@ jobs: | |||||||
|       - name: Eslint |       - name: Eslint | ||||||
|         working-directory: ${{ matrix.project }}/ |         working-directory: ${{ matrix.project }}/ | ||||||
|         run: npm run lint |         run: npm run lint | ||||||
|  |   lint-lockfile: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v4 | ||||||
|  |       - working-directory: web/ | ||||||
|  |         run: | | ||||||
|  |           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] | ||||||
|   lint-build: |   lint-build: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -95,6 +102,7 @@ jobs: | |||||||
|         run: npm run lit-analyse |         run: npm run lit-analyse | ||||||
|   ci-web-mark: |   ci-web-mark: | ||||||
|     needs: |     needs: | ||||||
|  |       - lint-lockfile | ||||||
|       - lint-eslint |       - lint-eslint | ||||||
|       - lint-prettier |       - lint-prettier | ||||||
|       - lint-lit-analyse |       - lint-lit-analyse | ||||||
|  | |||||||
							
								
								
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/ci-website.yml
									
									
									
									
										vendored
									
									
								
							| @ -12,6 +12,13 @@ on: | |||||||
|       - version-* |       - version-* | ||||||
|  |  | ||||||
| jobs: | jobs: | ||||||
|  |   lint-lockfile: | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v4 | ||||||
|  |       - working-directory: website/ | ||||||
|  |         run: | | ||||||
|  |           [ -z "$(jq -r '.packages | to_entries[] | select((.key | startswith("node_modules")) and (.value | has("resolved") | not)) | .key' < package-lock.json)" ] | ||||||
|   lint-prettier: |   lint-prettier: | ||||||
|     runs-on: ubuntu-latest |     runs-on: ubuntu-latest | ||||||
|     steps: |     steps: | ||||||
| @ -62,6 +69,7 @@ jobs: | |||||||
|         run: npm run ${{ matrix.job }} |         run: npm run ${{ matrix.job }} | ||||||
|   ci-website-mark: |   ci-website-mark: | ||||||
|     needs: |     needs: | ||||||
|  |       - lint-lockfile | ||||||
|       - lint-prettier |       - lint-prettier | ||||||
|       - test |       - test | ||||||
|       - build |       - build | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|  |  | ||||||
| from os import environ | from os import environ | ||||||
|  |  | ||||||
| __version__ = "2024.2.3" | __version__ = "2024.4.4" | ||||||
| ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -154,12 +154,18 @@ class GroupViewSet(UsedByMixin, ModelViewSet): | |||||||
|  |  | ||||||
|         pk = IntegerField(required=True) |         pk = IntegerField(required=True) | ||||||
|  |  | ||||||
|     queryset = Group.objects.all().select_related("parent").prefetch_related("users") |     queryset = Group.objects.none() | ||||||
|     serializer_class = GroupSerializer |     serializer_class = GroupSerializer | ||||||
|     search_fields = ["name", "is_superuser"] |     search_fields = ["name", "is_superuser"] | ||||||
|     filterset_class = GroupFilter |     filterset_class = GroupFilter | ||||||
|     ordering = ["name"] |     ordering = ["name"] | ||||||
|  |  | ||||||
|  |     def get_queryset(self): | ||||||
|  |         base_qs = Group.objects.all().select_related("parent").prefetch_related("roles") | ||||||
|  |         if self.serializer_class(context={"request": self.request})._should_include_users: | ||||||
|  |             base_qs = base_qs.prefetch_related("users") | ||||||
|  |         return base_qs | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         parameters=[ |         parameters=[ | ||||||
|             OpenApiParameter("include_users", bool, default=True), |             OpenApiParameter("include_users", bool, default=True), | ||||||
|  | |||||||
| @ -2,6 +2,7 @@ | |||||||
|  |  | ||||||
| from typing import Any | from typing import Any | ||||||
|  |  | ||||||
|  | from django.utils.timezone import now | ||||||
| from django_filters.rest_framework import DjangoFilterBackend | from django_filters.rest_framework import DjangoFilterBackend | ||||||
| from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer | from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer | ||||||
| from guardian.shortcuts import assign_perm, get_anonymous_user | from guardian.shortcuts import assign_perm, get_anonymous_user | ||||||
| @ -27,7 +28,6 @@ from authentik.core.models import ( | |||||||
|     TokenIntents, |     TokenIntents, | ||||||
|     User, |     User, | ||||||
|     default_token_duration, |     default_token_duration, | ||||||
|     token_expires_from_timedelta, |  | ||||||
| ) | ) | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.events.utils import model_to_dict | from authentik.events.utils import model_to_dict | ||||||
| @ -45,6 +45,13 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | |||||||
|         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: |         if SERIALIZER_CONTEXT_BLUEPRINT in self.context: | ||||||
|             self.fields["key"] = CharField(required=False) |             self.fields["key"] = CharField(required=False) | ||||||
|  |  | ||||||
|  |     def validate_user(self, user: User): | ||||||
|  |         """Ensure user of token cannot be changed""" | ||||||
|  |         if self.instance and self.instance.user_id: | ||||||
|  |             if user.pk != self.instance.user_id: | ||||||
|  |                 raise ValidationError("User cannot be changed") | ||||||
|  |         return user | ||||||
|  |  | ||||||
|     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: |     def validate(self, attrs: dict[Any, str]) -> dict[Any, str]: | ||||||
|         """Ensure only API or App password tokens are created.""" |         """Ensure only API or App password tokens are created.""" | ||||||
|         request: Request = self.context.get("request") |         request: Request = self.context.get("request") | ||||||
| @ -68,15 +75,17 @@ class TokenSerializer(ManagedSerializer, ModelSerializer): | |||||||
|             max_token_lifetime_dt = default_token_duration() |             max_token_lifetime_dt = default_token_duration() | ||||||
|             if max_token_lifetime is not None: |             if max_token_lifetime is not None: | ||||||
|                 try: |                 try: | ||||||
|                     max_token_lifetime_dt = timedelta_from_string(max_token_lifetime) |                     max_token_lifetime_dt = now() + timedelta_from_string(max_token_lifetime) | ||||||
|                 except ValueError: |                 except ValueError: | ||||||
|                     max_token_lifetime_dt = default_token_duration() |                     pass | ||||||
|  |  | ||||||
|             if "expires" in attrs and attrs.get("expires") > token_expires_from_timedelta( |             if "expires" in attrs and attrs.get("expires") > max_token_lifetime_dt: | ||||||
|                 max_token_lifetime_dt |  | ||||||
|             ): |  | ||||||
|                 raise ValidationError( |                 raise ValidationError( | ||||||
|                     {"expires": f"Token expires exceeds maximum lifetime ({max_token_lifetime})."} |                     { | ||||||
|  |                         "expires": ( | ||||||
|  |                             f"Token expires exceeds maximum lifetime ({max_token_lifetime_dt} UTC)." | ||||||
|  |                         ) | ||||||
|  |                     } | ||||||
|                 ) |                 ) | ||||||
|         elif attrs.get("intent") == TokenIntents.INTENT_API: |         elif attrs.get("intent") == TokenIntents.INTENT_API: | ||||||
|             # For API tokens, expires cannot be overridden |             # For API tokens, expires cannot be overridden | ||||||
|  | |||||||
| @ -14,6 +14,7 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
|  |  | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
|  | from authentik.rbac.filters import ObjectFilter | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeleteAction(Enum): | class DeleteAction(Enum): | ||||||
| @ -53,7 +54,7 @@ class UsedByMixin: | |||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         responses={200: UsedBySerializer(many=True)}, |         responses={200: UsedBySerializer(many=True)}, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) | ||||||
|     def used_by(self, request: Request, *args, **kwargs) -> Response: |     def used_by(self, request: Request, *args, **kwargs) -> Response: | ||||||
|         """Get a list of all objects that use this object""" |         """Get a list of all objects that use this object""" | ||||||
|         model: Model = self.get_object() |         model: Model = self.get_object() | ||||||
|  | |||||||
| @ -407,8 +407,11 @@ class UserViewSet(UsedByMixin, ModelViewSet): | |||||||
|     search_fields = ["username", "name", "is_active", "email", "uuid"] |     search_fields = ["username", "name", "is_active", "email", "uuid"] | ||||||
|     filterset_class = UsersFilter |     filterset_class = UsersFilter | ||||||
|  |  | ||||||
|     def get_queryset(self):  # pragma: no cover |     def get_queryset(self): | ||||||
|         return User.objects.all().exclude_anonymous().prefetch_related("ak_groups") |         base_qs = User.objects.all().exclude_anonymous() | ||||||
|  |         if self.serializer_class(context={"request": self.request})._should_include_groups: | ||||||
|  |             base_qs = base_qs.prefetch_related("ak_groups") | ||||||
|  |         return base_qs | ||||||
|  |  | ||||||
|     @extend_schema( |     @extend_schema( | ||||||
|         parameters=[ |         parameters=[ | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| """authentik core models""" | """authentik core models""" | ||||||
|  |  | ||||||
| from datetime import datetime, timedelta | from datetime import datetime | ||||||
| from hashlib import sha256 | from hashlib import sha256 | ||||||
| from typing import Any, Optional, Self | from typing import Any, Optional, Self | ||||||
| from uuid import uuid4 | from uuid import uuid4 | ||||||
| @ -68,11 +68,6 @@ def default_token_duration() -> datetime: | |||||||
|     return now() + timedelta_from_string(token_duration) |     return now() + timedelta_from_string(token_duration) | ||||||
|  |  | ||||||
|  |  | ||||||
| def token_expires_from_timedelta(dt: timedelta) -> datetime: |  | ||||||
|     """Return a `datetime.datetime` object with the duration of the Token""" |  | ||||||
|     return now() + dt |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def default_token_key() -> str: | def default_token_key() -> str: | ||||||
|     """Default token key""" |     """Default token key""" | ||||||
|     current_tenant = get_current_tenant() |     current_tenant = get_current_tenant() | ||||||
| @ -637,7 +632,7 @@ class UserSourceConnection(SerializerModel, CreatedUpdatedModel): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"User-source connection (user={self.user.username}, source={self.source.slug})" |         return f"User-source connection (user={self.user_id}, source={self.source_id})" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("user", "source"),) |         unique_together = (("user", "source"),) | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ from django.utils.translation import gettext as _ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | from authentik.core.models import Source, SourceUserMatchingModes, User, UserSourceConnection | ||||||
| from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostUserEnrollmentStage | from authentik.core.sources.stage import PLAN_CONTEXT_SOURCES_CONNECTION, PostSourceStage | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.flows.exceptions import FlowNonApplicableException | from authentik.flows.exceptions import FlowNonApplicableException | ||||||
| from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage | from authentik.flows.models import Flow, FlowToken, Stage, in_memory_stage | ||||||
| @ -100,8 +100,6 @@ class SourceFlowManager: | |||||||
|         if self.request.user.is_authenticated: |         if self.request.user.is_authenticated: | ||||||
|             new_connection.user = self.request.user |             new_connection.user = self.request.user | ||||||
|             new_connection = self.update_connection(new_connection, **kwargs) |             new_connection = self.update_connection(new_connection, **kwargs) | ||||||
|  |  | ||||||
|             new_connection.save() |  | ||||||
|             return Action.LINK, new_connection |             return Action.LINK, new_connection | ||||||
|  |  | ||||||
|         existing_connections = self.connection_type.objects.filter( |         existing_connections = self.connection_type.objects.filter( | ||||||
| @ -148,7 +146,6 @@ class SourceFlowManager: | |||||||
|         ]: |         ]: | ||||||
|             new_connection.user = user |             new_connection.user = user | ||||||
|             new_connection = self.update_connection(new_connection, **kwargs) |             new_connection = self.update_connection(new_connection, **kwargs) | ||||||
|             new_connection.save() |  | ||||||
|             return Action.LINK, new_connection |             return Action.LINK, new_connection | ||||||
|         if self.source.user_matching_mode in [ |         if self.source.user_matching_mode in [ | ||||||
|             SourceUserMatchingModes.EMAIL_DENY, |             SourceUserMatchingModes.EMAIL_DENY, | ||||||
| @ -209,13 +206,9 @@ class SourceFlowManager: | |||||||
|  |  | ||||||
|     def get_stages_to_append(self, flow: Flow) -> list[Stage]: |     def get_stages_to_append(self, flow: Flow) -> list[Stage]: | ||||||
|         """Hook to override stages which are appended to the flow""" |         """Hook to override stages which are appended to the flow""" | ||||||
|         if not self.source.enrollment_flow: |  | ||||||
|             return [] |  | ||||||
|         if flow.slug == self.source.enrollment_flow.slug: |  | ||||||
|         return [ |         return [ | ||||||
|                 in_memory_stage(PostUserEnrollmentStage), |             in_memory_stage(PostSourceStage), | ||||||
|         ] |         ] | ||||||
|         return [] |  | ||||||
|  |  | ||||||
|     def _prepare_flow( |     def _prepare_flow( | ||||||
|         self, |         self, | ||||||
| @ -269,6 +262,9 @@ class SourceFlowManager: | |||||||
|             ) |             ) | ||||||
|         # We run the Flow planner here so we can pass the Pending user in the context |         # We run the Flow planner here so we can pass the Pending user in the context | ||||||
|         planner = FlowPlanner(flow) |         planner = FlowPlanner(flow) | ||||||
|  |         # We append some stages so the initial flow we get might be empty | ||||||
|  |         planner.allow_empty_flows = True | ||||||
|  |         planner.use_cache = False | ||||||
|         plan = planner.plan(self.request, kwargs) |         plan = planner.plan(self.request, kwargs) | ||||||
|         for stage in self.get_stages_to_append(flow): |         for stage in self.get_stages_to_append(flow): | ||||||
|             plan.append_stage(stage) |             plan.append_stage(stage) | ||||||
| @ -327,7 +323,7 @@ class SourceFlowManager: | |||||||
|             reverse( |             reverse( | ||||||
|                 "authentik_core:if-user", |                 "authentik_core:if-user", | ||||||
|             ) |             ) | ||||||
|             + f"#/settings;page-{self.source.slug}" |             + "#/settings;page-sources" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     def handle_enroll( |     def handle_enroll( | ||||||
|  | |||||||
| @ -10,7 +10,7 @@ from authentik.flows.stage import StageView | |||||||
| PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" | PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection" | ||||||
|  |  | ||||||
|  |  | ||||||
| class PostUserEnrollmentStage(StageView): | class PostSourceStage(StageView): | ||||||
|     """Dynamically injected stage which saves the Connection after |     """Dynamically injected stage which saves the Connection after | ||||||
|     the user has been enrolled.""" |     the user has been enrolled.""" | ||||||
|  |  | ||||||
| @ -21,7 +21,9 @@ class PostUserEnrollmentStage(StageView): | |||||||
|         ] |         ] | ||||||
|         user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] |         user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] | ||||||
|         connection.user = user |         connection.user = user | ||||||
|  |         linked = connection.pk is None | ||||||
|         connection.save() |         connection.save() | ||||||
|  |         if linked: | ||||||
|             Event.new( |             Event.new( | ||||||
|                 EventAction.SOURCE_LINKED, |                 EventAction.SOURCE_LINKED, | ||||||
|                 message="Linked Source", |                 message="Linked Source", | ||||||
|  | |||||||
| @ -2,7 +2,9 @@ | |||||||
|  |  | ||||||
| from datetime import datetime, timedelta | from datetime import datetime, timedelta | ||||||
|  |  | ||||||
|  | from django.conf import ImproperlyConfigured | ||||||
| from django.contrib.sessions.backends.cache import KEY_PREFIX | from django.contrib.sessions.backends.cache import KEY_PREFIX | ||||||
|  | from django.contrib.sessions.backends.db import SessionStore as DBSessionStore | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.utils.timezone import now | from django.utils.timezone import now | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -15,6 +17,7 @@ from authentik.core.models import ( | |||||||
|     User, |     User, | ||||||
| ) | ) | ||||||
| from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | from authentik.events.system_tasks import SystemTask, TaskStatus, prefill_task | ||||||
|  | from authentik.lib.config import CONFIG | ||||||
| from authentik.root.celery import CELERY_APP | from authentik.root.celery import CELERY_APP | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -39,6 +42,8 @@ def clean_expired_models(self: SystemTask): | |||||||
|     amount = 0 |     amount = 0 | ||||||
|  |  | ||||||
|     for session in AuthenticatedSession.objects.all(): |     for session in AuthenticatedSession.objects.all(): | ||||||
|  |         match CONFIG.get("session_storage", "cache"): | ||||||
|  |             case "cache": | ||||||
|                 cache_key = f"{KEY_PREFIX}{session.session_key}" |                 cache_key = f"{KEY_PREFIX}{session.session_key}" | ||||||
|                 value = None |                 value = None | ||||||
|                 try: |                 try: | ||||||
| @ -49,6 +54,19 @@ def clean_expired_models(self: SystemTask): | |||||||
|                 if not value: |                 if not value: | ||||||
|                     session.delete() |                     session.delete() | ||||||
|                     amount += 1 |                     amount += 1 | ||||||
|  |             case "db": | ||||||
|  |                 if not ( | ||||||
|  |                     DBSessionStore.get_model_class() | ||||||
|  |                     .objects.filter(session_key=session.session_key, expire_date__gt=now()) | ||||||
|  |                     .exists() | ||||||
|  |                 ): | ||||||
|  |                     session.delete() | ||||||
|  |                     amount += 1 | ||||||
|  |             case _: | ||||||
|  |                 # Should never happen, as we check for other values in authentik/root/settings.py | ||||||
|  |                 raise ImproperlyConfigured( | ||||||
|  |                     "Invalid session_storage setting, allowed values are db and cache" | ||||||
|  |                 ) | ||||||
|     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) |     LOGGER.debug("Expired sessions", model=AuthenticatedSession, amount=amount) | ||||||
|  |  | ||||||
|     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") |     messages.append(f"Expired {amount} {AuthenticatedSession._meta.verbose_name_plural}") | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ from guardian.shortcuts import assign_perm | |||||||
| from rest_framework.test import APITestCase | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.core.tests.utils import create_test_user | from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -16,6 +16,13 @@ class TestGroupsAPI(APITestCase): | |||||||
|         self.login_user = create_test_user() |         self.login_user = create_test_user() | ||||||
|         self.user = User.objects.create(username="test-user") |         self.user = User.objects.create(username="test-user") | ||||||
|  |  | ||||||
|  |     def test_list_with_users(self): | ||||||
|  |         """Test listing with users""" | ||||||
|  |         admin = create_test_admin_user() | ||||||
|  |         self.client.force_login(admin) | ||||||
|  |         response = self.client.get(reverse("authentik_api:group-list"), {"include_users": "true"}) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|     def test_add_user(self): |     def test_add_user(self): | ||||||
|         """Test add_user""" |         """Test add_user""" | ||||||
|         group = Group.objects.create(name=generate_id()) |         group = Group.objects.create(name=generate_id()) | ||||||
|  | |||||||
| @ -2,11 +2,15 @@ | |||||||
|  |  | ||||||
| from django.contrib.auth.models import AnonymousUser | from django.contrib.auth.models import AnonymousUser | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
|  | from django.urls import reverse | ||||||
| from guardian.utils import get_anonymous_user | from guardian.utils import get_anonymous_user | ||||||
|  |  | ||||||
| from authentik.core.models import SourceUserMatchingModes, User | from authentik.core.models import SourceUserMatchingModes, User | ||||||
| from authentik.core.sources.flow_manager import Action | from authentik.core.sources.flow_manager import Action | ||||||
|  | from authentik.core.sources.stage import PostSourceStage | ||||||
| from authentik.core.tests.utils import create_test_flow | from authentik.core.tests.utils import create_test_flow | ||||||
|  | from authentik.flows.planner import FlowPlan | ||||||
|  | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.tests.utils import get_request | from authentik.lib.tests.utils import get_request | ||||||
| from authentik.policies.denied import AccessDeniedResponse | from authentik.policies.denied import AccessDeniedResponse | ||||||
| @ -21,42 +25,62 @@ class TestSourceFlowManager(TestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.source: OAuthSource = OAuthSource.objects.create(name="test") |         self.authentication_flow = create_test_flow() | ||||||
|  |         self.enrollment_flow = create_test_flow() | ||||||
|  |         self.source: OAuthSource = OAuthSource.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=generate_id(), | ||||||
|  |             authentication_flow=self.authentication_flow, | ||||||
|  |             enrollment_flow=self.enrollment_flow, | ||||||
|  |         ) | ||||||
|         self.identifier = generate_id() |         self.identifier = generate_id() | ||||||
|  |  | ||||||
|     def test_unauthenticated_enroll(self): |     def test_unauthenticated_enroll(self): | ||||||
|         """Test un-authenticated user enrolling""" |         """Test un-authenticated user enrolling""" | ||||||
|         flow_manager = OAuthSourceFlowManager( |         request = get_request("/", user=AnonymousUser()) | ||||||
|             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} |         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||||
|         ) |  | ||||||
|         action, _ = flow_manager.get_action() |         action, _ = flow_manager.get_action() | ||||||
|         self.assertEqual(action, Action.ENROLL) |         self.assertEqual(action, Action.ENROLL) | ||||||
|         flow_manager.get_flow() |         response = flow_manager.get_flow() | ||||||
|  |         self.assertEqual(response.status_code, 302) | ||||||
|  |         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||||
|  |         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) | ||||||
|  |  | ||||||
|     def test_unauthenticated_auth(self): |     def test_unauthenticated_auth(self): | ||||||
|         """Test un-authenticated user authenticating""" |         """Test un-authenticated user authenticating""" | ||||||
|         UserOAuthSourceConnection.objects.create( |         UserOAuthSourceConnection.objects.create( | ||||||
|             user=get_anonymous_user(), source=self.source, identifier=self.identifier |             user=get_anonymous_user(), source=self.source, identifier=self.identifier | ||||||
|         ) |         ) | ||||||
|  |         request = get_request("/", user=AnonymousUser()) | ||||||
|         flow_manager = OAuthSourceFlowManager( |         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||||
|             self.source, get_request("/", user=AnonymousUser()), self.identifier, {} |  | ||||||
|         ) |  | ||||||
|         action, _ = flow_manager.get_action() |         action, _ = flow_manager.get_action() | ||||||
|         self.assertEqual(action, Action.AUTH) |         self.assertEqual(action, Action.AUTH) | ||||||
|         flow_manager.get_flow() |         response = flow_manager.get_flow() | ||||||
|  |         self.assertEqual(response.status_code, 302) | ||||||
|  |         flow_plan: FlowPlan = request.session[SESSION_KEY_PLAN] | ||||||
|  |         self.assertEqual(flow_plan.bindings[0].stage.view, PostSourceStage) | ||||||
|  |  | ||||||
|     def test_authenticated_link(self): |     def test_authenticated_link(self): | ||||||
|         """Test authenticated user linking""" |         """Test authenticated user linking""" | ||||||
|         UserOAuthSourceConnection.objects.create( |  | ||||||
|             user=get_anonymous_user(), source=self.source, identifier=self.identifier |  | ||||||
|         ) |  | ||||||
|         user = User.objects.create(username="foo", email="foo@bar.baz") |         user = User.objects.create(username="foo", email="foo@bar.baz") | ||||||
|         flow_manager = OAuthSourceFlowManager( |         request = get_request("/", user=user) | ||||||
|             self.source, get_request("/", user=user), self.identifier, {} |         flow_manager = OAuthSourceFlowManager(self.source, request, self.identifier, {}) | ||||||
|         ) |         action, connection = flow_manager.get_action() | ||||||
|         action, _ = flow_manager.get_action() |  | ||||||
|         self.assertEqual(action, Action.LINK) |         self.assertEqual(action, Action.LINK) | ||||||
|  |         self.assertIsNone(connection.pk) | ||||||
|  |         response = flow_manager.get_flow() | ||||||
|  |         self.assertEqual(response.status_code, 302) | ||||||
|  |         self.assertEqual( | ||||||
|  |             response.url, | ||||||
|  |             reverse("authentik_core:if-user") + "#/settings;page-sources", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def test_unauthenticated_link(self): | ||||||
|  |         """Test un-authenticated user linking""" | ||||||
|  |         flow_manager = OAuthSourceFlowManager(self.source, get_request("/"), self.identifier, {}) | ||||||
|  |         action, connection = flow_manager.get_action() | ||||||
|  |         self.assertEqual(action, Action.LINK) | ||||||
|  |         self.assertIsNone(connection.pk) | ||||||
|         flow_manager.get_flow() |         flow_manager.get_flow() | ||||||
|  |  | ||||||
|     def test_unauthenticated_enroll_email(self): |     def test_unauthenticated_enroll_email(self): | ||||||
|  | |||||||
| @ -13,9 +13,8 @@ from authentik.core.models import ( | |||||||
|     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, |     USER_ATTRIBUTE_TOKEN_MAXIMUM_LIFETIME, | ||||||
|     Token, |     Token, | ||||||
|     TokenIntents, |     TokenIntents, | ||||||
|     User, |  | ||||||
| ) | ) | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user, create_test_user | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -24,7 +23,7 @@ class TestTokenAPI(APITestCase): | |||||||
|  |  | ||||||
|     def setUp(self) -> None: |     def setUp(self) -> None: | ||||||
|         super().setUp() |         super().setUp() | ||||||
|         self.user = User.objects.create(username="testuser") |         self.user = create_test_user() | ||||||
|         self.admin = create_test_admin_user() |         self.admin = create_test_admin_user() | ||||||
|         self.client.force_login(self.user) |         self.client.force_login(self.user) | ||||||
|  |  | ||||||
| @ -154,6 +153,24 @@ class TestTokenAPI(APITestCase): | |||||||
|         self.assertEqual(token.expiring, True) |         self.assertEqual(token.expiring, True) | ||||||
|         self.assertNotEqual(token.expires.timestamp(), expires.timestamp()) |         self.assertNotEqual(token.expires.timestamp(), expires.timestamp()) | ||||||
|  |  | ||||||
|  |     def test_token_change_user(self): | ||||||
|  |         """Test creating a token and then changing the user""" | ||||||
|  |         ident = generate_id() | ||||||
|  |         response = self.client.post(reverse("authentik_api:token-list"), {"identifier": ident}) | ||||||
|  |         self.assertEqual(response.status_code, 201) | ||||||
|  |         token = Token.objects.get(identifier=ident) | ||||||
|  |         self.assertEqual(token.user, self.user) | ||||||
|  |         self.assertEqual(token.intent, TokenIntents.INTENT_API) | ||||||
|  |         self.assertEqual(token.expiring, True) | ||||||
|  |         self.assertTrue(self.user.has_perm("authentik_core.view_token_key", token)) | ||||||
|  |         response = self.client.put( | ||||||
|  |             reverse("authentik_api:token-detail", kwargs={"identifier": ident}), | ||||||
|  |             data={"identifier": "user_token_poc_v3", "intent": "api", "user": self.admin.pk}, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 400) | ||||||
|  |         token.refresh_from_db() | ||||||
|  |         self.assertEqual(token.user, self.user) | ||||||
|  |  | ||||||
|     def test_list(self): |     def test_list(self): | ||||||
|         """Test Token List (Test normal authentication)""" |         """Test Token List (Test normal authentication)""" | ||||||
|         Token.objects.all().delete() |         Token.objects.all().delete() | ||||||
|  | |||||||
| @ -41,6 +41,12 @@ class TestUsersAPI(APITestCase): | |||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|  |     def test_list_with_groups(self): | ||||||
|  |         """Test listing with groups""" | ||||||
|  |         self.client.force_login(self.admin) | ||||||
|  |         response = self.client.get(reverse("authentik_api:user-list"), {"include_groups": "true"}) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|     def test_metrics(self): |     def test_metrics(self): | ||||||
|         """Test user's metrics""" |         """Test user's metrics""" | ||||||
|         self.client.force_login(self.admin) |         self.client.force_login(self.admin) | ||||||
|  | |||||||
| @ -8,7 +8,6 @@ from rest_framework.test import APITestCase | |||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.lib.config import CONFIG |  | ||||||
| from authentik.tenants.utils import get_current_tenant | from authentik.tenants.utils import get_current_tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -25,7 +24,6 @@ class TestUsersAvatars(APITestCase): | |||||||
|         tenant.avatars = mode |         tenant.avatars = mode | ||||||
|         tenant.save() |         tenant.save() | ||||||
|  |  | ||||||
|     @CONFIG.patch("avatars", "none") |  | ||||||
|     def test_avatars_none(self): |     def test_avatars_none(self): | ||||||
|         """Test avatars none""" |         """Test avatars none""" | ||||||
|         self.set_avatar_mode("none") |         self.set_avatar_mode("none") | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ from django.utils.text import slugify | |||||||
|  |  | ||||||
| from authentik.brands.models import Brand | from authentik.brands.models import Brand | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.crypto.builder import CertificateBuilder | from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.flows.models import Flow, FlowDesignation | from authentik.flows.models import Flow, FlowDesignation | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| @ -50,12 +50,10 @@ def create_test_brand(**kwargs) -> Brand: | |||||||
|     return Brand.objects.create(domain=uid, default=True, **kwargs) |     return Brand.objects.create(domain=uid, default=True, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_test_cert(use_ec_private_key=False) -> CertificateKeyPair: | def create_test_cert(alg=PrivateKeyAlg.RSA) -> CertificateKeyPair: | ||||||
|     """Generate a certificate for testing""" |     """Generate a certificate for testing""" | ||||||
|     builder = CertificateBuilder( |     builder = CertificateBuilder(f"{generate_id()}.self-signed.goauthentik.io") | ||||||
|         name=f"{generate_id()}.self-signed.goauthentik.io", |     builder.alg = alg | ||||||
|         use_ec_private_key=use_ec_private_key, |  | ||||||
|     ) |  | ||||||
|     builder.build( |     builder.build( | ||||||
|         subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"], |         subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"], | ||||||
|         validity_days=360, |         validity_days=360, | ||||||
|  | |||||||
| @ -14,7 +14,13 @@ from drf_spectacular.types import OpenApiTypes | |||||||
| from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema | ||||||
| from rest_framework.decorators import action | from rest_framework.decorators import action | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField, DateTimeField, IntegerField, SerializerMethodField | from rest_framework.fields import ( | ||||||
|  |     CharField, | ||||||
|  |     ChoiceField, | ||||||
|  |     DateTimeField, | ||||||
|  |     IntegerField, | ||||||
|  |     SerializerMethodField, | ||||||
|  | ) | ||||||
| from rest_framework.filters import OrderingFilter, SearchFilter | from rest_framework.filters import OrderingFilter, SearchFilter | ||||||
| from rest_framework.request import Request | from rest_framework.request import Request | ||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
| @ -26,10 +32,11 @@ from authentik.api.authorization import SecretKeyFilter | |||||||
| from authentik.core.api.used_by import UsedByMixin | from authentik.core.api.used_by import UsedByMixin | ||||||
| from authentik.core.api.utils import PassiveSerializer | from authentik.core.api.utils import PassiveSerializer | ||||||
| from authentik.crypto.apps import MANAGED_KEY | from authentik.crypto.apps import MANAGED_KEY | ||||||
| from authentik.crypto.builder import CertificateBuilder | from authentik.crypto.builder import CertificateBuilder, PrivateKeyAlg | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.events.models import Event, EventAction | from authentik.events.models import Event, EventAction | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  | from authentik.rbac.filters import ObjectFilter | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -178,6 +185,7 @@ class CertificateGenerationSerializer(PassiveSerializer): | |||||||
|     common_name = CharField() |     common_name = CharField() | ||||||
|     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) |     subject_alt_name = CharField(required=False, allow_blank=True, label=_("Subject-alt name")) | ||||||
|     validity_days = IntegerField(initial=365) |     validity_days = IntegerField(initial=365) | ||||||
|  |     alg = ChoiceField(default=PrivateKeyAlg.RSA, choices=PrivateKeyAlg.choices) | ||||||
|  |  | ||||||
|  |  | ||||||
| class CertificateKeyPairFilter(FilterSet): | class CertificateKeyPairFilter(FilterSet): | ||||||
| @ -240,6 +248,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         raw_san = data.validated_data.get("subject_alt_name", "") |         raw_san = data.validated_data.get("subject_alt_name", "") | ||||||
|         sans = raw_san.split(",") if raw_san != "" else [] |         sans = raw_san.split(",") if raw_san != "" else [] | ||||||
|         builder = CertificateBuilder(data.validated_data["common_name"]) |         builder = CertificateBuilder(data.validated_data["common_name"]) | ||||||
|  |         builder.alg = data.validated_data["alg"] | ||||||
|         builder.build( |         builder.build( | ||||||
|             subject_alt_names=sans, |             subject_alt_names=sans, | ||||||
|             validity_days=int(data.validated_data["validity_days"]), |             validity_days=int(data.validated_data["validity_days"]), | ||||||
| @ -258,7 +267,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ], |         ], | ||||||
|         responses={200: CertificateDataSerializer(many=False)}, |         responses={200: CertificateDataSerializer(many=False)}, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) | ||||||
|     def view_certificate(self, request: Request, pk: str) -> Response: |     def view_certificate(self, request: Request, pk: str) -> Response: | ||||||
|         """Return certificate-key pairs certificate and log access""" |         """Return certificate-key pairs certificate and log access""" | ||||||
|         certificate: CertificateKeyPair = self.get_object() |         certificate: CertificateKeyPair = self.get_object() | ||||||
| @ -288,7 +297,7 @@ class CertificateKeyPairViewSet(UsedByMixin, ModelViewSet): | |||||||
|         ], |         ], | ||||||
|         responses={200: CertificateDataSerializer(many=False)}, |         responses={200: CertificateDataSerializer(many=False)}, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) | ||||||
|     def view_private_key(self, request: Request, pk: str) -> Response: |     def view_private_key(self, request: Request, pk: str) -> Response: | ||||||
|         """Return certificate-key pairs private key and log access""" |         """Return certificate-key pairs private key and log access""" | ||||||
|         certificate: CertificateKeyPair = self.get_object() |         certificate: CertificateKeyPair = self.get_object() | ||||||
|  | |||||||
| @ -9,20 +9,28 @@ from cryptography.hazmat.primitives import hashes, serialization | |||||||
| from cryptography.hazmat.primitives.asymmetric import ec, rsa | from cryptography.hazmat.primitives.asymmetric import ec, rsa | ||||||
| from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes | ||||||
| from cryptography.x509.oid import NameOID | from cryptography.x509.oid import NameOID | ||||||
|  | from django.db import models | ||||||
|  | from django.utils.translation import gettext_lazy as _ | ||||||
|  |  | ||||||
| from authentik import __version__ | from authentik import __version__ | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PrivateKeyAlg(models.TextChoices): | ||||||
|  |     """Algorithm to create private key with""" | ||||||
|  |  | ||||||
|  |     RSA = "rsa", _("rsa") | ||||||
|  |     ECDSA = "ecdsa", _("ecdsa") | ||||||
|  |  | ||||||
|  |  | ||||||
| class CertificateBuilder: | class CertificateBuilder: | ||||||
|     """Build self-signed certificates""" |     """Build self-signed certificates""" | ||||||
|  |  | ||||||
|     common_name: str |     common_name: str | ||||||
|  |     alg: PrivateKeyAlg | ||||||
|  |  | ||||||
|     _use_ec_private_key: bool |     def __init__(self, name: str): | ||||||
|  |         self.alg = PrivateKeyAlg.RSA | ||||||
|     def __init__(self, name: str, use_ec_private_key=False): |  | ||||||
|         self._use_ec_private_key = use_ec_private_key |  | ||||||
|         self.__public_key = None |         self.__public_key = None | ||||||
|         self.__private_key = None |         self.__private_key = None | ||||||
|         self.__builder = None |         self.__builder = None | ||||||
| @ -42,11 +50,13 @@ class CertificateBuilder: | |||||||
|  |  | ||||||
|     def generate_private_key(self) -> PrivateKeyTypes: |     def generate_private_key(self) -> PrivateKeyTypes: | ||||||
|         """Generate private key""" |         """Generate private key""" | ||||||
|         if self._use_ec_private_key: |         if self.alg == PrivateKeyAlg.ECDSA: | ||||||
|             return ec.generate_private_key(curve=ec.SECP256R1()) |             return ec.generate_private_key(curve=ec.SECP256R1()) | ||||||
|  |         if self.alg == PrivateKeyAlg.RSA: | ||||||
|             return rsa.generate_private_key( |             return rsa.generate_private_key( | ||||||
|                 public_exponent=65537, key_size=4096, backend=default_backend() |                 public_exponent=65537, key_size=4096, backend=default_backend() | ||||||
|             ) |             ) | ||||||
|  |         raise ValueError(f"Invalid alg: {self.alg}") | ||||||
|  |  | ||||||
|     def build( |     def build( | ||||||
|         self, |         self, | ||||||
|  | |||||||
| @ -214,6 +214,46 @@ class TestCrypto(APITestCase): | |||||||
|         self.assertEqual(200, response.status_code) |         self.assertEqual(200, response.status_code) | ||||||
|         self.assertIn("Content-Disposition", response) |         self.assertIn("Content-Disposition", response) | ||||||
|  |  | ||||||
|  |     def test_certificate_download_denied(self): | ||||||
|  |         """Test certificate export (download)""" | ||||||
|  |         self.client.logout() | ||||||
|  |         keypair = create_test_cert() | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:certificatekeypair-view-certificate", | ||||||
|  |                 kwargs={"pk": keypair.pk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(403, response.status_code) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:certificatekeypair-view-certificate", | ||||||
|  |                 kwargs={"pk": keypair.pk}, | ||||||
|  |             ), | ||||||
|  |             data={"download": True}, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(403, response.status_code) | ||||||
|  |  | ||||||
|  |     def test_private_key_download_denied(self): | ||||||
|  |         """Test private_key export (download)""" | ||||||
|  |         self.client.logout() | ||||||
|  |         keypair = create_test_cert() | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:certificatekeypair-view-private-key", | ||||||
|  |                 kwargs={"pk": keypair.pk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(403, response.status_code) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:certificatekeypair-view-private-key", | ||||||
|  |                 kwargs={"pk": keypair.pk}, | ||||||
|  |             ), | ||||||
|  |             data={"download": True}, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(403, response.status_code) | ||||||
|  |  | ||||||
|     def test_used_by(self): |     def test_used_by(self): | ||||||
|         """Test used_by endpoint""" |         """Test used_by endpoint""" | ||||||
|         self.client.force_login(create_test_admin_user()) |         self.client.force_login(create_test_admin_user()) | ||||||
| @ -246,6 +286,26 @@ class TestCrypto(APITestCase): | |||||||
|             ], |             ], | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_used_by_denied(self): | ||||||
|  |         """Test used_by endpoint""" | ||||||
|  |         self.client.logout() | ||||||
|  |         keypair = create_test_cert() | ||||||
|  |         OAuth2Provider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             client_id="test", | ||||||
|  |             client_secret=generate_key(), | ||||||
|  |             authorization_flow=create_test_flow(), | ||||||
|  |             redirect_uris="http://localhost", | ||||||
|  |             signing_key=keypair, | ||||||
|  |         ) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse( | ||||||
|  |                 "authentik_api:certificatekeypair-used-by", | ||||||
|  |                 kwargs={"pk": keypair.pk}, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(403, response.status_code) | ||||||
|  |  | ||||||
|     def test_discovery(self): |     def test_discovery(self): | ||||||
|         """Test certificate discovery""" |         """Test certificate discovery""" | ||||||
|         name = generate_id() |         name = generate_id() | ||||||
|  | |||||||
| @ -2,11 +2,12 @@ | |||||||
|  |  | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from functools import partial | from functools import partial | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
| from django.apps.registry import apps | from django.apps.registry import apps | ||||||
| from django.core.files import File | from django.core.files import File | ||||||
| from django.db import connection | from django.db import connection | ||||||
| from django.db.models import Model | from django.db.models import ManyToManyRel, Model | ||||||
| from django.db.models.expressions import BaseExpression, Combinable | from django.db.models.expressions import BaseExpression, Combinable | ||||||
| from django.db.models.signals import post_init | from django.db.models.signals import post_init | ||||||
| from django.http import HttpRequest | from django.http import HttpRequest | ||||||
| @ -44,7 +45,7 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         post_init.disconnect(dispatch_uid=request.request_id) |         post_init.disconnect(dispatch_uid=request.request_id) | ||||||
|  |  | ||||||
|     def serialize_simple(self, model: Model) -> dict: |     def serialize_simple(self, model: Model) -> dict: | ||||||
|         """Serialize a model in a very simple way. No ForeginKeys or other relationships are |         """Serialize a model in a very simple way. No ForeignKeys or other relationships are | ||||||
|         resolved""" |         resolved""" | ||||||
|         data = {} |         data = {} | ||||||
|         deferred_fields = model.get_deferred_fields() |         deferred_fields = model.get_deferred_fields() | ||||||
| @ -70,6 +71,9 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         for key, value in before.items(): |         for key, value in before.items(): | ||||||
|             if after.get(key) != value: |             if after.get(key) != value: | ||||||
|                 diff[key] = {"previous_value": value, "new_value": after.get(key)} |                 diff[key] = {"previous_value": value, "new_value": after.get(key)} | ||||||
|  |         for key, value in after.items(): | ||||||
|  |             if key not in before and key not in diff and before.get(key) != value: | ||||||
|  |                 diff[key] = {"previous_value": before.get(key), "new_value": value} | ||||||
|         return sanitize_item(diff) |         return sanitize_item(diff) | ||||||
|  |  | ||||||
|     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): |     def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_): | ||||||
| @ -98,8 +102,37 @@ class EnterpriseAuditMiddleware(AuditMiddleware): | |||||||
|         thread_kwargs = {} |         thread_kwargs = {} | ||||||
|         if hasattr(instance, "_previous_state") or created: |         if hasattr(instance, "_previous_state") or created: | ||||||
|             prev_state = getattr(instance, "_previous_state", {}) |             prev_state = getattr(instance, "_previous_state", {}) | ||||||
|  |             if created: | ||||||
|  |                 prev_state = {} | ||||||
|             # Get current state |             # Get current state | ||||||
|             new_state = self.serialize_simple(instance) |             new_state = self.serialize_simple(instance) | ||||||
|             diff = self.diff(prev_state, new_state) |             diff = self.diff(prev_state, new_state) | ||||||
|             thread_kwargs["diff"] = diff |             thread_kwargs["diff"] = diff | ||||||
|         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) |         return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_) | ||||||
|  |  | ||||||
|  |     def m2m_changed_handler(  # noqa: PLR0913 | ||||||
|  |         self, | ||||||
|  |         request: HttpRequest, | ||||||
|  |         sender, | ||||||
|  |         instance: Model, | ||||||
|  |         action: str, | ||||||
|  |         pk_set: set[Any], | ||||||
|  |         thread_kwargs: dict | None = None, | ||||||
|  |         **_, | ||||||
|  |     ): | ||||||
|  |         thread_kwargs = {} | ||||||
|  |         m2m_field = None | ||||||
|  |         # For the audit log we don't care about `pre_` or `post_` so we trim that part off | ||||||
|  |         _, _, action_direction = action.partition("_") | ||||||
|  |         # resolve the "through" model to an actual field | ||||||
|  |         for field in instance._meta.get_fields(): | ||||||
|  |             if not isinstance(field, ManyToManyRel): | ||||||
|  |                 continue | ||||||
|  |             if field.through == sender: | ||||||
|  |                 m2m_field = field | ||||||
|  |         if m2m_field: | ||||||
|  |             # If we're clearing we just set the "flag" to True | ||||||
|  |             if action_direction == "clear": | ||||||
|  |                 pk_set = True | ||||||
|  |             thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}} | ||||||
|  |         return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs) | ||||||
|  | |||||||
| @ -1,9 +1,22 @@ | |||||||
|  | from unittest.mock import PropertyMock, patch | ||||||
|  |  | ||||||
| from django.apps import apps | from django.apps import apps | ||||||
| from django.conf import settings | from django.conf import settings | ||||||
| from django.test import TestCase | from django.urls import reverse | ||||||
|  | from rest_framework.test import APITestCase | ||||||
|  |  | ||||||
|  | from authentik.core.models import Group, User | ||||||
|  | from authentik.core.tests.utils import create_test_admin_user | ||||||
|  | from authentik.events.models import Event, EventAction | ||||||
|  | from authentik.events.utils import sanitize_item | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestEnterpriseAudit(TestCase): | class TestEnterpriseAudit(APITestCase): | ||||||
|  |     """Test audit middleware""" | ||||||
|  |  | ||||||
|  |     def setUp(self) -> None: | ||||||
|  |         self.user = create_test_admin_user() | ||||||
|  |  | ||||||
|     def test_import(self): |     def test_import(self): | ||||||
|         """Ensure middleware is imported when app.ready is called""" |         """Ensure middleware is imported when app.ready is called""" | ||||||
| @ -16,3 +29,182 @@ class TestEnterpriseAudit(TestCase): | |||||||
|         self.assertIn( |         self.assertIn( | ||||||
|             "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE |             "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware", settings.MIDDLEWARE | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||||
|  |         PropertyMock(return_value=True), | ||||||
|  |     ) | ||||||
|  |     def test_create(self): | ||||||
|  |         """Test create audit log""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         username = generate_id() | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse("authentik_api:user-list"), | ||||||
|  |             data={"name": generate_id(), "username": username, "groups": [], "path": "foo"}, | ||||||
|  |         ) | ||||||
|  |         user = User.objects.get(username=username) | ||||||
|  |         self.assertEqual(response.status_code, 201) | ||||||
|  |         events = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_CREATED, | ||||||
|  |             context__model__model_name="user", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__pk=user.pk, | ||||||
|  |         ) | ||||||
|  |         event = events.first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertIsNotNone(event.context["diff"]) | ||||||
|  |         diff = event.context["diff"] | ||||||
|  |         self.assertEqual( | ||||||
|  |             diff, | ||||||
|  |             { | ||||||
|  |                 "name": { | ||||||
|  |                     "new_value": user.name, | ||||||
|  |                     "previous_value": None, | ||||||
|  |                 }, | ||||||
|  |                 "path": {"new_value": "foo", "previous_value": None}, | ||||||
|  |                 "type": {"new_value": "internal", "previous_value": None}, | ||||||
|  |                 "uuid": { | ||||||
|  |                     "new_value": user.uuid.hex, | ||||||
|  |                     "previous_value": None, | ||||||
|  |                 }, | ||||||
|  |                 "email": {"new_value": "", "previous_value": None}, | ||||||
|  |                 "username": { | ||||||
|  |                     "new_value": user.username, | ||||||
|  |                     "previous_value": None, | ||||||
|  |                 }, | ||||||
|  |                 "is_active": {"new_value": True, "previous_value": None}, | ||||||
|  |                 "attributes": {"new_value": {}, "previous_value": None}, | ||||||
|  |                 "date_joined": { | ||||||
|  |                     "new_value": sanitize_item(user.date_joined), | ||||||
|  |                     "previous_value": None, | ||||||
|  |                 }, | ||||||
|  |                 "first_name": {"new_value": "", "previous_value": None}, | ||||||
|  |                 "id": {"new_value": user.pk, "previous_value": None}, | ||||||
|  |                 "last_name": {"new_value": "", "previous_value": None}, | ||||||
|  |                 "password": {"new_value": "********************", "previous_value": None}, | ||||||
|  |                 "password_change_date": { | ||||||
|  |                     "new_value": sanitize_item(user.password_change_date), | ||||||
|  |                     "previous_value": None, | ||||||
|  |                 }, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||||
|  |         PropertyMock(return_value=True), | ||||||
|  |     ) | ||||||
|  |     def test_update(self): | ||||||
|  |         """Test update audit log""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         user = create_test_admin_user() | ||||||
|  |         current_name = user.name | ||||||
|  |         new_name = generate_id() | ||||||
|  |         response = self.client.patch( | ||||||
|  |             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), | ||||||
|  |             data={"name": new_name}, | ||||||
|  |         ) | ||||||
|  |         user.refresh_from_db() | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  |         events = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_UPDATED, | ||||||
|  |             context__model__model_name="user", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__pk=user.pk, | ||||||
|  |         ) | ||||||
|  |         event = events.first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertIsNotNone(event.context["diff"]) | ||||||
|  |         diff = event.context["diff"] | ||||||
|  |         self.assertEqual( | ||||||
|  |             diff, | ||||||
|  |             { | ||||||
|  |                 "name": { | ||||||
|  |                     "new_value": new_name, | ||||||
|  |                     "previous_value": current_name, | ||||||
|  |                 }, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||||
|  |         PropertyMock(return_value=True), | ||||||
|  |     ) | ||||||
|  |     def test_delete(self): | ||||||
|  |         """Test delete audit log""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         user = create_test_admin_user() | ||||||
|  |         response = self.client.delete( | ||||||
|  |             reverse("authentik_api:user-detail", kwargs={"pk": user.id}), | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 204) | ||||||
|  |         events = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_DELETED, | ||||||
|  |             context__model__model_name="user", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__pk=user.pk, | ||||||
|  |         ) | ||||||
|  |         event = events.first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertNotIn("diff", event.context) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||||
|  |         PropertyMock(return_value=True), | ||||||
|  |     ) | ||||||
|  |     def test_m2m_add(self): | ||||||
|  |         """Test m2m add audit log""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         user = create_test_admin_user() | ||||||
|  |         group = Group.objects.create(name=generate_id()) | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse("authentik_api:group-add-user", kwargs={"pk": group.group_uuid}), | ||||||
|  |             data={ | ||||||
|  |                 "pk": user.pk, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 204) | ||||||
|  |         events = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_UPDATED, | ||||||
|  |             context__model__model_name="group", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__pk=group.pk.hex, | ||||||
|  |         ) | ||||||
|  |         event = events.first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertIsNotNone(event.context["diff"]) | ||||||
|  |         diff = event.context["diff"] | ||||||
|  |         self.assertEqual( | ||||||
|  |             diff, | ||||||
|  |             {"users": {"add": [user.pk]}}, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     @patch( | ||||||
|  |         "authentik.enterprise.audit.middleware.EnterpriseAuditMiddleware.enabled", | ||||||
|  |         PropertyMock(return_value=True), | ||||||
|  |     ) | ||||||
|  |     def test_m2m_remove(self): | ||||||
|  |         """Test m2m remove audit log""" | ||||||
|  |         self.client.force_login(self.user) | ||||||
|  |         user = create_test_admin_user() | ||||||
|  |         group = Group.objects.create(name=generate_id()) | ||||||
|  |         response = self.client.post( | ||||||
|  |             reverse("authentik_api:group-remove-user", kwargs={"pk": group.group_uuid}), | ||||||
|  |             data={ | ||||||
|  |                 "pk": user.pk, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 204) | ||||||
|  |         events = Event.objects.filter( | ||||||
|  |             action=EventAction.MODEL_UPDATED, | ||||||
|  |             context__model__model_name="group", | ||||||
|  |             context__model__app="authentik_core", | ||||||
|  |             context__model__pk=group.pk.hex, | ||||||
|  |         ) | ||||||
|  |         event = events.first() | ||||||
|  |         self.assertIsNotNone(event) | ||||||
|  |         self.assertIsNotNone(event.context["diff"]) | ||||||
|  |         diff = event.context["diff"] | ||||||
|  |         self.assertEqual( | ||||||
|  |             diff, | ||||||
|  |             {"users": {"remove": [user.pk]}}, | ||||||
|  |         ) | ||||||
|  | |||||||
| @ -201,10 +201,7 @@ class ConnectionToken(ExpiringModel): | |||||||
|         return settings |         return settings | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return ( |         return f"RAC Connection token {self.session_id} to {self.provider_id}/{self.endpoint_id}" | ||||||
|             f"RAC Connection token {self.session.user} to " |  | ||||||
|             f"{self.endpoint.provider.name}/{self.endpoint.name}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("RAC Connection token") |         verbose_name = _("RAC Connection token") | ||||||
|  | |||||||
| @ -116,12 +116,12 @@ class AuditMiddleware: | |||||||
|             return user |             return user | ||||||
|         user = getattr(request, "user", self.anonymous_user) |         user = getattr(request, "user", self.anonymous_user) | ||||||
|         if not user.is_authenticated: |         if not user.is_authenticated: | ||||||
|  |             self._ensure_fallback_user() | ||||||
|             return self.anonymous_user |             return self.anonymous_user | ||||||
|         return user |         return user | ||||||
|  |  | ||||||
|     def connect(self, request: HttpRequest): |     def connect(self, request: HttpRequest): | ||||||
|         """Connect signal for automatic logging""" |         """Connect signal for automatic logging""" | ||||||
|         self._ensure_fallback_user() |  | ||||||
|         if not hasattr(request, "request_id"): |         if not hasattr(request, "request_id"): | ||||||
|             return |             return | ||||||
|         post_save.connect( |         post_save.connect( | ||||||
| @ -214,7 +214,15 @@ class AuditMiddleware: | |||||||
|             model=model_to_dict(instance), |             model=model_to_dict(instance), | ||||||
|         ).run() |         ).run() | ||||||
|  |  | ||||||
|     def m2m_changed_handler(self, request: HttpRequest, sender, instance: Model, action: str, **_): |     def m2m_changed_handler( | ||||||
|  |         self, | ||||||
|  |         request: HttpRequest, | ||||||
|  |         sender, | ||||||
|  |         instance: Model, | ||||||
|  |         action: str, | ||||||
|  |         thread_kwargs: dict | None = None, | ||||||
|  |         **_, | ||||||
|  |     ): | ||||||
|         """Signal handler for all object's m2m_changed""" |         """Signal handler for all object's m2m_changed""" | ||||||
|         if action not in ["pre_add", "pre_remove", "post_clear"]: |         if action not in ["pre_add", "pre_remove", "post_clear"]: | ||||||
|             return |             return | ||||||
| @ -229,4 +237,5 @@ class AuditMiddleware: | |||||||
|             request, |             request, | ||||||
|             user=user, |             user=user, | ||||||
|             model=model_to_dict(instance), |             model=model_to_dict(instance), | ||||||
|  |             **thread_kwargs, | ||||||
|         ).run() |         ).run() | ||||||
|  | |||||||
| @ -556,7 +556,7 @@ class Notification(SerializerModel): | |||||||
|             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH |             if len(self.body) > NOTIFICATION_SUMMARY_LENGTH | ||||||
|             else self.body |             else self.body | ||||||
|         ) |         ) | ||||||
|         return f"Notification for user {self.user}: {body_trunc}" |         return f"Notification for user {self.user_id}: {body_trunc}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Notification") |         verbose_name = _("Notification") | ||||||
|  | |||||||
| @ -119,7 +119,7 @@ class SystemTask(TenantTask): | |||||||
|                 "task_call_kwargs": sanitize_item(kwargs), |                 "task_call_kwargs": sanitize_item(kwargs), | ||||||
|                 "status": self._status, |                 "status": self._status, | ||||||
|                 "messages": sanitize_item(self._messages), |                 "messages": sanitize_item(self._messages), | ||||||
|                 "expires": now() + timedelta(hours=self.result_timeout_hours), |                 "expires": now() + timedelta(hours=self.result_timeout_hours + 3), | ||||||
|                 "expiring": True, |                 "expiring": True, | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  | |||||||
							
								
								
									
										35
									
								
								authentik/events/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								authentik/events/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,35 @@ | |||||||
|  | """authentik event models tests""" | ||||||
|  |  | ||||||
|  | from collections.abc import Callable | ||||||
|  |  | ||||||
|  | from django.db.models import Model | ||||||
|  | from django.test import TestCase | ||||||
|  |  | ||||||
|  | from authentik.core.models import default_token_key | ||||||
|  | from authentik.lib.utils.reflection import get_apps | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestModels(TestCase): | ||||||
|  |     """Test Models""" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def model_tester_factory(test_model: type[Model]) -> Callable: | ||||||
|  |     """Test models' __str__ and __repr__""" | ||||||
|  |  | ||||||
|  |     def tester(self: TestModels): | ||||||
|  |         allowed = 0 | ||||||
|  |         # Token-like objects need to lookup the current tenant to get the default token length | ||||||
|  |         for field in test_model._meta.fields: | ||||||
|  |             if field.default == default_token_key: | ||||||
|  |                 allowed += 1 | ||||||
|  |         with self.assertNumQueries(allowed): | ||||||
|  |             str(test_model()) | ||||||
|  |         with self.assertNumQueries(allowed): | ||||||
|  |             repr(test_model()) | ||||||
|  |  | ||||||
|  |     return tester | ||||||
|  |  | ||||||
|  |  | ||||||
|  | for app in get_apps(): | ||||||
|  |     for model in app.get_models(): | ||||||
|  |         setattr(TestModels, f"test_{app.label}_{model.__name__}", model_tester_factory(model)) | ||||||
| @ -33,6 +33,7 @@ from authentik.lib.utils.file import ( | |||||||
| ) | ) | ||||||
| from authentik.lib.views import bad_request_message | from authentik.lib.views import bad_request_message | ||||||
| from authentik.rbac.decorators import permission_required | from authentik.rbac.decorators import permission_required | ||||||
|  | from authentik.rbac.filters import ObjectFilter | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -277,8 +278,8 @@ class FlowViewSet(UsedByMixin, ModelViewSet): | |||||||
|             400: OpenApiResponse(description="Flow not applicable"), |             400: OpenApiResponse(description="Flow not applicable"), | ||||||
|         }, |         }, | ||||||
|     ) |     ) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) | ||||||
|     def execute(self, request: Request, _slug: str): |     def execute(self, request: Request, slug: str): | ||||||
|         """Execute flow for current user""" |         """Execute flow for current user""" | ||||||
|         # Because we pre-plan the flow here, and not in the planner, we need to manually clear |         # Because we pre-plan the flow here, and not in the planner, we need to manually clear | ||||||
|         # the history of the inspector |         # the history of the inspector | ||||||
|  | |||||||
| @ -203,6 +203,7 @@ class FlowPlanner: | |||||||
|                 "f(plan): building plan", |                 "f(plan): building plan", | ||||||
|             ) |             ) | ||||||
|             plan = self._build_plan(user, request, default_context) |             plan = self._build_plan(user, request, default_context) | ||||||
|  |             if self.use_cache: | ||||||
|                 cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT) |                 cache.set(cache_key(self.flow, user), plan, CACHE_TIMEOUT) | ||||||
|             if not plan.bindings and not self.allow_empty_flows: |             if not plan.bindings and not self.allow_empty_flows: | ||||||
|                 raise EmptyFlowException() |                 raise EmptyFlowException() | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ from rest_framework.test import APITestCase | |||||||
| from authentik.core.tests.utils import create_test_admin_user | from authentik.core.tests.utils import create_test_admin_user | ||||||
| from authentik.flows.api.stages import StageSerializer, StageViewSet | from authentik.flows.api.stages import StageSerializer, StageViewSet | ||||||
| from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage | from authentik.flows.models import Flow, FlowDesignation, FlowStageBinding, Stage | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
| from authentik.policies.dummy.models import DummyPolicy | from authentik.policies.dummy.models import DummyPolicy | ||||||
| from authentik.policies.models import PolicyBinding | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.stages.dummy.models import DummyStage | from authentik.stages.dummy.models import DummyStage | ||||||
| @ -101,3 +102,21 @@ class TestFlowsAPI(APITestCase): | |||||||
|             reverse("authentik_api:stage-types"), |             reverse("authentik_api:stage-types"), | ||||||
|         ) |         ) | ||||||
|         self.assertEqual(response.status_code, 200) |         self.assertEqual(response.status_code, 200) | ||||||
|  |  | ||||||
|  |     def test_execute(self): | ||||||
|  |         """Test execute endpoint""" | ||||||
|  |         user = create_test_admin_user() | ||||||
|  |         self.client.force_login(user) | ||||||
|  |  | ||||||
|  |         flow = Flow.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=generate_id(), | ||||||
|  |             designation=FlowDesignation.AUTHENTICATION, | ||||||
|  |         ) | ||||||
|  |         FlowStageBinding.objects.create( | ||||||
|  |             target=flow, stage=DummyStage.objects.create(name=generate_id()), order=0 | ||||||
|  |         ) | ||||||
|  |         response = self.client.get( | ||||||
|  |             reverse("authentik_api:flow-execute", kwargs={"slug": flow.slug}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(response.status_code, 200) | ||||||
|  | |||||||
| @ -53,6 +53,7 @@ cache: | |||||||
|  |  | ||||||
| # result_backend: | # result_backend: | ||||||
| #   url: "" | #   url: "" | ||||||
|  | #   transport_options: "" | ||||||
|  |  | ||||||
| debug: false | debug: false | ||||||
| remote_debug: false | remote_debug: false | ||||||
|  | |||||||
| @ -23,6 +23,7 @@ from authentik.outposts.models import ( | |||||||
|     KubernetesServiceConnection, |     KubernetesServiceConnection, | ||||||
|     OutpostServiceConnection, |     OutpostServiceConnection, | ||||||
| ) | ) | ||||||
|  | from authentik.rbac.filters import ObjectFilter | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer): | class ServiceConnectionSerializer(ModelSerializer, MetaNameSerializer): | ||||||
| @ -88,7 +89,7 @@ class ServiceConnectionViewSet( | |||||||
|         return Response(TypeCreateSerializer(data, many=True).data) |         return Response(TypeCreateSerializer(data, many=True).data) | ||||||
|  |  | ||||||
|     @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)}) |     @extend_schema(responses={200: ServiceConnectionStateSerializer(many=False)}) | ||||||
|     @action(detail=True, pagination_class=None, filter_backends=[]) |     @action(detail=True, pagination_class=None, filter_backends=[ObjectFilter]) | ||||||
|     def state(self, request: Request, pk: str) -> Response: |     def state(self, request: Request, pk: str) -> Response: | ||||||
|         """Get the service connection's state""" |         """Get the service connection's state""" | ||||||
|         connection = self.get_object() |         connection = self.get_object() | ||||||
|  | |||||||
| @ -326,7 +326,7 @@ class AuthorizationCode(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("Authorization Codes") |         verbose_name_plural = _("Authorization Codes") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Authorization code for {self.provider} for user {self.user}" |         return f"Authorization code for {self.provider_id} for user {self.user_id}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def serializer(self) -> Serializer: |     def serializer(self) -> Serializer: | ||||||
| @ -356,7 +356,7 @@ class AccessToken(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("OAuth2 Access Tokens") |         verbose_name_plural = _("OAuth2 Access Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Access Token for {self.provider} for user {self.user}" |         return f"Access Token for {self.provider_id} for user {self.user_id}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def id_token(self) -> IDToken: |     def id_token(self) -> IDToken: | ||||||
| @ -399,7 +399,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): | |||||||
|         verbose_name_plural = _("OAuth2 Refresh Tokens") |         verbose_name_plural = _("OAuth2 Refresh Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Refresh Token for {self.provider} for user {self.user}" |         return f"Refresh Token for {self.provider_id} for user {self.user_id}" | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def id_token(self) -> IDToken: |     def id_token(self) -> IDToken: | ||||||
| @ -443,4 +443,4 @@ class DeviceToken(ExpiringModel): | |||||||
|         verbose_name_plural = _("Device Tokens") |         verbose_name_plural = _("Device Tokens") | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Device Token for {self.provider}" |         return f"Device Token for {self.provider_id}" | ||||||
|  | |||||||
| @ -4,9 +4,10 @@ from urllib.parse import urlencode | |||||||
|  |  | ||||||
| from django.urls import reverse | from django.urls import reverse | ||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application, Group | ||||||
| from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.policies.models import PolicyBinding | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.tests.utils import OAuthTestCase | from authentik.providers.oauth2.tests.utils import OAuthTestCase | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||||
| @ -77,3 +78,23 @@ class TesOAuth2DeviceInit(OAuthTestCase): | |||||||
|             + "?" |             + "?" | ||||||
|             + urlencode({QS_KEY_CODE: token.user_code}), |             + urlencode({QS_KEY_CODE: token.user_code}), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_device_init_denied(self): | ||||||
|  |         """Test device init""" | ||||||
|  |         group = Group.objects.create(name="foo") | ||||||
|  |         PolicyBinding.objects.create( | ||||||
|  |             group=group, | ||||||
|  |             target=self.application, | ||||||
|  |             order=0, | ||||||
|  |         ) | ||||||
|  |         token = DeviceToken.objects.create( | ||||||
|  |             user_code="foo", | ||||||
|  |             provider=self.provider, | ||||||
|  |         ) | ||||||
|  |         res = self.client.get( | ||||||
|  |             reverse("authentik_providers_oauth2_root:device-login") | ||||||
|  |             + "?" | ||||||
|  |             + urlencode({QS_KEY_CODE: token.user_code}) | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(res.status_code, 200) | ||||||
|  |         self.assertIn(b"Permission denied", res.content) | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ from jwt import PyJWKSet | |||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_cert, create_test_flow | ||||||
|  | from authentik.crypto.builder import PrivateKeyAlg | ||||||
| from authentik.crypto.models import CertificateKeyPair | from authentik.crypto.models import CertificateKeyPair | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.oauth2.models import OAuth2Provider | from authentik.providers.oauth2.models import OAuth2Provider | ||||||
| @ -82,7 +83,7 @@ class TestJWKS(OAuthTestCase): | |||||||
|             client_id="test", |             client_id="test", | ||||||
|             authorization_flow=create_test_flow(), |             authorization_flow=create_test_flow(), | ||||||
|             redirect_uris="http://local.invalid", |             redirect_uris="http://local.invalid", | ||||||
|             signing_key=create_test_cert(use_ec_private_key=True), |             signing_key=create_test_cert(PrivateKeyAlg.ECDSA), | ||||||
|         ) |         ) | ||||||
|         app = Application.objects.create(name="test", slug="test", provider=provider) |         app = Application.objects.create(name="test", slug="test", provider=provider) | ||||||
|         response = self.client.get( |         response = self.client.get( | ||||||
|  | |||||||
| @ -11,10 +11,11 @@ from django.views.decorators.csrf import csrf_exempt | |||||||
| from rest_framework.throttling import AnonRateThrottle | from rest_framework.throttling import AnonRateThrottle | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
|  | from authentik.core.models import Application | ||||||
| from authentik.lib.config import CONFIG | from authentik.lib.config import CONFIG | ||||||
| from authentik.lib.utils.time import timedelta_from_string | from authentik.lib.utils.time import timedelta_from_string | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | ||||||
| from authentik.providers.oauth2.views.device_init import QS_KEY_CODE, get_application | from authentik.providers.oauth2.views.device_init import QS_KEY_CODE | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -37,7 +38,9 @@ class DeviceView(View): | |||||||
|         ).first() |         ).first() | ||||||
|         if not provider: |         if not provider: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         if not get_application(provider): |         try: | ||||||
|  |             _ = provider.application | ||||||
|  |         except Application.DoesNotExist: | ||||||
|             return HttpResponseBadRequest() |             return HttpResponseBadRequest() | ||||||
|         self.provider = provider |         self.provider = provider | ||||||
|         self.client_id = client_id |         self.client_id = client_id | ||||||
|  | |||||||
| @ -1,8 +1,9 @@ | |||||||
| """Device flow views""" | """Device flow views""" | ||||||
|  |  | ||||||
|  | from typing import Any | ||||||
|  |  | ||||||
| from django.http import HttpRequest, HttpResponse | from django.http import HttpRequest, HttpResponse | ||||||
| from django.utils.translation import gettext as _ | from django.utils.translation import gettext as _ | ||||||
| from django.views import View |  | ||||||
| from rest_framework.exceptions import ValidationError | from rest_framework.exceptions import ValidationError | ||||||
| from rest_framework.fields import CharField, IntegerField | from rest_framework.fields import CharField, IntegerField | ||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
| @ -16,7 +17,8 @@ from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, | |||||||
| from authentik.flows.stage import ChallengeStageView | from authentik.flows.stage import ChallengeStageView | ||||||
| from authentik.flows.views.executor import SESSION_KEY_PLAN | from authentik.flows.views.executor import SESSION_KEY_PLAN | ||||||
| from authentik.lib.utils.urls import redirect_with_qs | from authentik.lib.utils.urls import redirect_with_qs | ||||||
| from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider | from authentik.policies.views import PolicyAccessView | ||||||
|  | from authentik.providers.oauth2.models import DeviceToken | ||||||
| from authentik.providers.oauth2.views.device_finish import ( | from authentik.providers.oauth2.views.device_finish import ( | ||||||
|     PLAN_CONTEXT_DEVICE, |     PLAN_CONTEXT_DEVICE, | ||||||
|     OAuthDeviceCodeFinishStage, |     OAuthDeviceCodeFinishStage, | ||||||
| @ -31,31 +33,23 @@ LOGGER = get_logger() | |||||||
| QS_KEY_CODE = "code"  # nosec | QS_KEY_CODE = "code"  # nosec | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_application(provider: OAuth2Provider) -> Application | None: | class CodeValidatorView(PolicyAccessView): | ||||||
|     """Get application from provider""" |     """Helper to validate frontside token""" | ||||||
|     try: |  | ||||||
|         app = provider.application |  | ||||||
|         if not app: |  | ||||||
|             return None |  | ||||||
|         return app |  | ||||||
|     except Application.DoesNotExist: |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |     def __init__(self, code: str, **kwargs: Any) -> None: | ||||||
|  |         super().__init__(**kwargs) | ||||||
|  |         self.code = code | ||||||
|  |  | ||||||
| def validate_code(code: int, request: HttpRequest) -> HttpResponse | None: |     def resolve_provider_application(self): | ||||||
|     """Validate user token""" |         self.token = DeviceToken.objects.filter(user_code=self.code).first() | ||||||
|     token = DeviceToken.objects.filter( |         if not self.token: | ||||||
|         user_code=code, |             raise Application.DoesNotExist | ||||||
|     ).first() |         self.provider = self.token.provider | ||||||
|     if not token: |         self.application = self.token.provider.application | ||||||
|         return None |  | ||||||
|  |  | ||||||
|     app = get_application(token.provider) |     def get(self, request: HttpRequest, *args, **kwargs): | ||||||
|     if not app: |         scope_descriptions = UserInfoView().get_scope_descriptions(self.token.scope, self.provider) | ||||||
|         return None |         planner = FlowPlanner(self.provider.authorization_flow) | ||||||
|  |  | ||||||
|     scope_descriptions = UserInfoView().get_scope_descriptions(token.scope, token.provider) |  | ||||||
|     planner = FlowPlanner(token.provider.authorization_flow) |  | ||||||
|         planner.allow_empty_flows = True |         planner.allow_empty_flows = True | ||||||
|         planner.use_cache = False |         planner.use_cache = False | ||||||
|         try: |         try: | ||||||
| @ -63,12 +57,12 @@ def validate_code(code: int, request: HttpRequest) -> HttpResponse | None: | |||||||
|                 request, |                 request, | ||||||
|                 { |                 { | ||||||
|                     PLAN_CONTEXT_SSO: True, |                     PLAN_CONTEXT_SSO: True, | ||||||
|                 PLAN_CONTEXT_APPLICATION: app, |                     PLAN_CONTEXT_APPLICATION: self.application, | ||||||
|                     # OAuth2 related params |                     # OAuth2 related params | ||||||
|                 PLAN_CONTEXT_DEVICE: token, |                     PLAN_CONTEXT_DEVICE: self.token, | ||||||
|                     # Consent related params |                     # Consent related params | ||||||
|                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") |                     PLAN_CONTEXT_CONSENT_HEADER: _("You're about to sign into %(application)s.") | ||||||
|                 % {"application": app.name}, |                     % {"application": self.application.name}, | ||||||
|                     PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, |                     PLAN_CONTEXT_CONSENT_PERMISSIONS: scope_descriptions, | ||||||
|                 }, |                 }, | ||||||
|             ) |             ) | ||||||
| @ -80,11 +74,11 @@ def validate_code(code: int, request: HttpRequest) -> HttpResponse | None: | |||||||
|         return redirect_with_qs( |         return redirect_with_qs( | ||||||
|             "authentik_core:if-flow", |             "authentik_core:if-flow", | ||||||
|             request.GET, |             request.GET, | ||||||
|         flow_slug=token.provider.authorization_flow.slug, |             flow_slug=self.token.provider.authorization_flow.slug, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DeviceEntryView(View): | class DeviceEntryView(PolicyAccessView): | ||||||
|     """View used to initiate the device-code flow, url entered by endusers""" |     """View used to initiate the device-code flow, url entered by endusers""" | ||||||
|  |  | ||||||
|     def dispatch(self, request: HttpRequest) -> HttpResponse: |     def dispatch(self, request: HttpRequest) -> HttpResponse: | ||||||
| @ -94,7 +88,9 @@ class DeviceEntryView(View): | |||||||
|             LOGGER.info("Brand has no device code flow configured", brand=brand) |             LOGGER.info("Brand has no device code flow configured", brand=brand) | ||||||
|             return HttpResponse(status=404) |             return HttpResponse(status=404) | ||||||
|         if QS_KEY_CODE in request.GET: |         if QS_KEY_CODE in request.GET: | ||||||
|             validation = validate_code(request.GET[QS_KEY_CODE], request) |             validation = CodeValidatorView(request.GET[QS_KEY_CODE], request=request).dispatch( | ||||||
|  |                 request | ||||||
|  |             ) | ||||||
|             if validation: |             if validation: | ||||||
|                 return validation |                 return validation | ||||||
|             LOGGER.info("Got code from query parameter but no matching token found") |             LOGGER.info("Got code from query parameter but no matching token found") | ||||||
| @ -131,7 +127,7 @@ class OAuthDeviceCodeChallengeResponse(ChallengeResponse): | |||||||
|  |  | ||||||
|     def validate_code(self, code: int) -> HttpResponse | None: |     def validate_code(self, code: int) -> HttpResponse | None: | ||||||
|         """Validate code and save the returned http response""" |         """Validate code and save the returned http response""" | ||||||
|         response = validate_code(code, self.stage.request) |         response = CodeValidatorView(code, request=self.stage.request).dispatch(self.stage.request) | ||||||
|         if not response: |         if not response: | ||||||
|             raise ValidationError(_("Invalid code"), "invalid") |             raise ValidationError(_("Invalid code"), "invalid") | ||||||
|         return response |         return response | ||||||
|  | |||||||
| @ -0,0 +1,44 @@ | |||||||
|  | # Generated by Django 5.0.4 on 2024-05-01 15:32 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_providers_saml", "0013_samlprovider_default_relay_state"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="samlprovider", | ||||||
|  |             name="digest_algorithm", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 choices=[ | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"), | ||||||
|  |                 ], | ||||||
|  |                 default="http://www.w3.org/2001/04/xmlenc#sha256", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="samlprovider", | ||||||
|  |             name="signature_algorithm", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 choices=[ | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"), | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"), | ||||||
|  |                 ], | ||||||
|  |                 default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -11,6 +11,10 @@ from authentik.crypto.models import CertificateKeyPair | |||||||
| from authentik.lib.utils.time import timedelta_string_validator | from authentik.lib.utils.time import timedelta_string_validator | ||||||
| from authentik.sources.saml.processors.constants import ( | from authentik.sources.saml.processors.constants import ( | ||||||
|     DSA_SHA1, |     DSA_SHA1, | ||||||
|  |     ECDSA_SHA1, | ||||||
|  |     ECDSA_SHA256, | ||||||
|  |     ECDSA_SHA384, | ||||||
|  |     ECDSA_SHA512, | ||||||
|     RSA_SHA1, |     RSA_SHA1, | ||||||
|     RSA_SHA256, |     RSA_SHA256, | ||||||
|     RSA_SHA384, |     RSA_SHA384, | ||||||
| @ -92,8 +96,7 @@ class SAMLProvider(Provider): | |||||||
|         ), |         ), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     digest_algorithm = models.CharField( |     digest_algorithm = models.TextField( | ||||||
|         max_length=50, |  | ||||||
|         choices=( |         choices=( | ||||||
|             (SHA1, _("SHA1")), |             (SHA1, _("SHA1")), | ||||||
|             (SHA256, _("SHA256")), |             (SHA256, _("SHA256")), | ||||||
| @ -102,13 +105,16 @@ class SAMLProvider(Provider): | |||||||
|         ), |         ), | ||||||
|         default=SHA256, |         default=SHA256, | ||||||
|     ) |     ) | ||||||
|     signature_algorithm = models.CharField( |     signature_algorithm = models.TextField( | ||||||
|         max_length=50, |  | ||||||
|         choices=( |         choices=( | ||||||
|             (RSA_SHA1, _("RSA-SHA1")), |             (RSA_SHA1, _("RSA-SHA1")), | ||||||
|             (RSA_SHA256, _("RSA-SHA256")), |             (RSA_SHA256, _("RSA-SHA256")), | ||||||
|             (RSA_SHA384, _("RSA-SHA384")), |             (RSA_SHA384, _("RSA-SHA384")), | ||||||
|             (RSA_SHA512, _("RSA-SHA512")), |             (RSA_SHA512, _("RSA-SHA512")), | ||||||
|  |             (ECDSA_SHA1, _("ECDSA-SHA1")), | ||||||
|  |             (ECDSA_SHA256, _("ECDSA-SHA256")), | ||||||
|  |             (ECDSA_SHA384, _("ECDSA-SHA384")), | ||||||
|  |             (ECDSA_SHA512, _("ECDSA-SHA512")), | ||||||
|             (DSA_SHA1, _("DSA-SHA1")), |             (DSA_SHA1, _("DSA-SHA1")), | ||||||
|         ), |         ), | ||||||
|         default=RSA_SHA256, |         default=RSA_SHA256, | ||||||
|  | |||||||
| @ -7,13 +7,14 @@ from lxml import etree  # nosec | |||||||
|  |  | ||||||
| from authentik.core.models import Application | from authentik.core.models import Application | ||||||
| from authentik.core.tests.utils import create_test_cert, create_test_flow | from authentik.core.tests.utils import create_test_cert, create_test_flow | ||||||
|  | from authentik.crypto.builder import PrivateKeyAlg | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.lib.tests.utils import load_fixture | from authentik.lib.tests.utils import load_fixture | ||||||
| from authentik.lib.xml import lxml_from_string | from authentik.lib.xml import lxml_from_string | ||||||
| from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider | from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider | ||||||
| from authentik.providers.saml.processors.metadata import MetadataProcessor | from authentik.providers.saml.processors.metadata import MetadataProcessor | ||||||
| from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | from authentik.providers.saml.processors.metadata_parser import ServiceProviderMetadataParser | ||||||
| from authentik.sources.saml.processors.constants import NS_MAP, NS_SAML_METADATA | from authentik.sources.saml.processors.constants import ECDSA_SHA256, NS_MAP, NS_SAML_METADATA | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestServiceProviderMetadataParser(TestCase): | class TestServiceProviderMetadataParser(TestCase): | ||||||
| @ -107,12 +108,41 @@ class TestServiceProviderMetadataParser(TestCase): | |||||||
|                 load_fixture("fixtures/cert.xml").replace("/apps/user_saml", "") |                 load_fixture("fixtures/cert.xml").replace("/apps/user_saml", "") | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def test_signature(self): |     def test_signature_rsa(self): | ||||||
|         """Test signature validation""" |         """Test signature validation (RSA)""" | ||||||
|         provider = SAMLProvider.objects.create( |         provider = SAMLProvider.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|             authorization_flow=self.flow, |             authorization_flow=self.flow, | ||||||
|             signing_kp=create_test_cert(), |             signing_kp=create_test_cert(PrivateKeyAlg.RSA), | ||||||
|  |         ) | ||||||
|  |         Application.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=generate_id(), | ||||||
|  |             provider=provider, | ||||||
|  |         ) | ||||||
|  |         request = self.factory.get("/") | ||||||
|  |         metadata = MetadataProcessor(provider, request).build_entity_descriptor() | ||||||
|  |  | ||||||
|  |         root = fromstring(metadata.encode()) | ||||||
|  |         xmlsec.tree.add_ids(root, ["ID"]) | ||||||
|  |         signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP) | ||||||
|  |         signature_node = signature_nodes[0] | ||||||
|  |         ctx = xmlsec.SignatureContext() | ||||||
|  |         key = xmlsec.Key.from_memory( | ||||||
|  |             provider.signing_kp.certificate_data, | ||||||
|  |             xmlsec.constants.KeyDataFormatCertPem, | ||||||
|  |             None, | ||||||
|  |         ) | ||||||
|  |         ctx.key = key | ||||||
|  |         ctx.verify(signature_node) | ||||||
|  |  | ||||||
|  |     def test_signature_ecdsa(self): | ||||||
|  |         """Test signature validation (ECDSA)""" | ||||||
|  |         provider = SAMLProvider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             authorization_flow=self.flow, | ||||||
|  |             signing_kp=create_test_cert(PrivateKeyAlg.ECDSA), | ||||||
|  |             signature_algorithm=ECDSA_SHA256, | ||||||
|         ) |         ) | ||||||
|         Application.objects.create( |         Application.objects.create( | ||||||
|             name=generate_id(), |             name=generate_id(), | ||||||
|  | |||||||
| @ -41,7 +41,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | |||||||
|         if not scim_group: |         if not scim_group: | ||||||
|             self.logger.debug("Group does not exist in SCIM, skipping") |             self.logger.debug("Group does not exist in SCIM, skipping") | ||||||
|             return None |             return None | ||||||
|         response = self._request("DELETE", f"/Groups/{scim_group.id}") |         response = self._request("DELETE", f"/Groups/{scim_group.scim_id}") | ||||||
|         scim_group.delete() |         scim_group.delete() | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
| @ -89,7 +89,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | |||||||
|         for user in connections: |         for user in connections: | ||||||
|             members.append( |             members.append( | ||||||
|                 GroupMember( |                 GroupMember( | ||||||
|                     value=user.id, |                     value=user.scim_id, | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         if members: |         if members: | ||||||
| @ -107,16 +107,19 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | |||||||
|                 exclude_unset=True, |                 exclude_unset=True, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|         SCIMGroup.objects.create(provider=self.provider, group=group, id=response["id"]) |         scim_id = response.get("id") | ||||||
|  |         if not scim_id or scim_id == "": | ||||||
|  |             raise StopSync("SCIM Response with missing or invalid `id`") | ||||||
|  |         SCIMGroup.objects.create(provider=self.provider, group=group, scim_id=scim_id) | ||||||
|  |  | ||||||
|     def _update(self, group: Group, connection: SCIMGroup): |     def _update(self, group: Group, connection: SCIMGroup): | ||||||
|         """Update existing group""" |         """Update existing group""" | ||||||
|         scim_group = self.to_scim(group) |         scim_group = self.to_scim(group) | ||||||
|         scim_group.id = connection.id |         scim_group.id = connection.scim_id | ||||||
|         try: |         try: | ||||||
|             return self._request( |             return self._request( | ||||||
|                 "PUT", |                 "PUT", | ||||||
|                 f"/Groups/{scim_group.id}", |                 f"/Groups/{connection.scim_id}", | ||||||
|                 json=scim_group.model_dump( |                 json=scim_group.model_dump( | ||||||
|                     mode="json", |                     mode="json", | ||||||
|                     exclude_unset=True, |                     exclude_unset=True, | ||||||
| @ -185,13 +188,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | |||||||
|             return |             return | ||||||
|         user_ids = list( |         user_ids = list( | ||||||
|             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( |             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( | ||||||
|                 "id", flat=True |                 "scim_id", flat=True | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         if len(user_ids) < 1: |         if len(user_ids) < 1: | ||||||
|             return |             return | ||||||
|         self._patch( |         self._patch( | ||||||
|             scim_group.id, |             scim_group.scim_id, | ||||||
|             PatchOperation( |             PatchOperation( | ||||||
|                 op=PatchOp.add, |                 op=PatchOp.add, | ||||||
|                 path="members", |                 path="members", | ||||||
| @ -211,13 +214,13 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]): | |||||||
|             return |             return | ||||||
|         user_ids = list( |         user_ids = list( | ||||||
|             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( |             SCIMUser.objects.filter(user__pk__in=users_set, provider=self.provider).values_list( | ||||||
|                 "id", flat=True |                 "scim_id", flat=True | ||||||
|             ) |             ) | ||||||
|         ) |         ) | ||||||
|         if len(user_ids) < 1: |         if len(user_ids) < 1: | ||||||
|             return |             return | ||||||
|         self._patch( |         self._patch( | ||||||
|             scim_group.id, |             scim_group.scim_id, | ||||||
|             PatchOperation( |             PatchOperation( | ||||||
|                 op=PatchOp.remove, |                 op=PatchOp.remove, | ||||||
|                 path="members", |                 path="members", | ||||||
|  | |||||||
| @ -9,13 +9,14 @@ from pydanticscim.service_provider import ( | |||||||
| ) | ) | ||||||
| from pydanticscim.user import User as BaseUser | from pydanticscim.user import User as BaseUser | ||||||
|  |  | ||||||
|  | SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" | ||||||
|  | SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" | ||||||
|  |  | ||||||
|  |  | ||||||
| class User(BaseUser): | class User(BaseUser): | ||||||
|     """Modified User schema with added externalId field""" |     """Modified User schema with added externalId field""" | ||||||
|  |  | ||||||
|     schemas: list[str] = [ |     schemas: list[str] = [SCIM_USER_SCHEMA] | ||||||
|         "urn:ietf:params:scim:schemas:core:2.0:User", |  | ||||||
|     ] |  | ||||||
|     externalId: str | None = None |     externalId: str | None = None | ||||||
|     meta: dict | None = None |     meta: dict | None = None | ||||||
|  |  | ||||||
| @ -23,9 +24,7 @@ class User(BaseUser): | |||||||
| class Group(BaseGroup): | class Group(BaseGroup): | ||||||
|     """Modified Group schema with added externalId field""" |     """Modified Group schema with added externalId field""" | ||||||
|  |  | ||||||
|     schemas: list[str] = [ |     schemas: list[str] = [SCIM_GROUP_SCHEMA] | ||||||
|         "urn:ietf:params:scim:schemas:core:2.0:Group", |  | ||||||
|     ] |  | ||||||
|     externalId: str | None = None |     externalId: str | None = None | ||||||
|     meta: dict | None = None |     meta: dict | None = None | ||||||
|  |  | ||||||
|  | |||||||
| @ -34,7 +34,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): | |||||||
|         if not scim_user: |         if not scim_user: | ||||||
|             self.logger.debug("User does not exist in SCIM, skipping") |             self.logger.debug("User does not exist in SCIM, skipping") | ||||||
|             return None |             return None | ||||||
|         response = self._request("DELETE", f"/Users/{scim_user.id}") |         response = self._request("DELETE", f"/Users/{scim_user.scim_id}") | ||||||
|         scim_user.delete() |         scim_user.delete() | ||||||
|         return response |         return response | ||||||
|  |  | ||||||
| @ -85,15 +85,18 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]): | |||||||
|                 exclude_unset=True, |                 exclude_unset=True, | ||||||
|             ), |             ), | ||||||
|         ) |         ) | ||||||
|         SCIMUser.objects.create(provider=self.provider, user=user, id=response["id"]) |         scim_id = response.get("id") | ||||||
|  |         if not scim_id or scim_id == "": | ||||||
|  |             raise StopSync("SCIM Response with missing or invalid `id`") | ||||||
|  |         SCIMUser.objects.create(provider=self.provider, user=user, scim_id=scim_id) | ||||||
|  |  | ||||||
|     def _update(self, user: User, connection: SCIMUser): |     def _update(self, user: User, connection: SCIMUser): | ||||||
|         """Update existing user""" |         """Update existing user""" | ||||||
|         scim_user = self.to_scim(user) |         scim_user = self.to_scim(user) | ||||||
|         scim_user.id = connection.id |         scim_user.id = connection.scim_id | ||||||
|         self._request( |         self._request( | ||||||
|             "PUT", |             "PUT", | ||||||
|             f"/Users/{connection.id}", |             f"/Users/{connection.scim_id}", | ||||||
|             json=scim_user.model_dump( |             json=scim_user.model_dump( | ||||||
|                 mode="json", |                 mode="json", | ||||||
|                 exclude_unset=True, |                 exclude_unset=True, | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ | |||||||
| from structlog.stdlib import get_logger | from structlog.stdlib import get_logger | ||||||
|  |  | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync | from authentik.providers.scim.tasks import scim_task_wrapper | ||||||
| from authentik.tenants.management import TenantCommand | from authentik.tenants.management import TenantCommand | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
| @ -21,4 +21,4 @@ class Command(TenantCommand): | |||||||
|             if not provider: |             if not provider: | ||||||
|                 LOGGER.warning("Provider does not exist", name=provider_name) |                 LOGGER.warning("Provider does not exist", name=provider_name) | ||||||
|                 continue |                 continue | ||||||
|             scim_sync.delay(provider.pk).get() |             scim_task_wrapper(provider.pk).get() | ||||||
|  | |||||||
| @ -0,0 +1,76 @@ | |||||||
|  | # Generated by Django 5.0.4 on 2024-05-03 12:38 | ||||||
|  |  | ||||||
|  | import uuid | ||||||
|  | from django.db import migrations, models | ||||||
|  | from django.apps.registry import Apps | ||||||
|  |  | ||||||
|  | from django.db.backends.base.schema import BaseDatabaseSchemaEditor | ||||||
|  |  | ||||||
|  | from authentik.lib.migrations import progress_bar | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def fix_scim_user_group_pk(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): | ||||||
|  |     SCIMUser = apps.get_model("authentik_providers_scim", "SCIMUser") | ||||||
|  |     SCIMGroup = apps.get_model("authentik_providers_scim", "SCIMGroup") | ||||||
|  |     db_alias = schema_editor.connection.alias | ||||||
|  |     print("\nFixing primary key for SCIM users, this might take a couple of minutes...") | ||||||
|  |     for user in progress_bar(SCIMUser.objects.using(db_alias).all()): | ||||||
|  |         SCIMUser.objects.using(db_alias).filter( | ||||||
|  |             pk=user.pk, user=user.user_id, provider=user.provider_id | ||||||
|  |         ).update(scim_id=user.pk, id=uuid.uuid4()) | ||||||
|  |  | ||||||
|  |     print("\nFixing primary key for SCIM groups, this might take a couple of minutes...") | ||||||
|  |     for group in progress_bar(SCIMGroup.objects.using(db_alias).all()): | ||||||
|  |         SCIMGroup.objects.using(db_alias).filter( | ||||||
|  |             pk=group.pk, group=group.group_id, provider=group.provider_id | ||||||
|  |         ).update(scim_id=group.pk, id=uuid.uuid4()) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ( | ||||||
|  |             "authentik_providers_scim", | ||||||
|  |             "0001_squashed_0006_rename_parent_group_scimprovider_filter_group", | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="scimgroup", | ||||||
|  |             name="scim_id", | ||||||
|  |             field=models.TextField(default="temp"), | ||||||
|  |             preserve_default=False, | ||||||
|  |         ), | ||||||
|  |         migrations.AddField( | ||||||
|  |             model_name="scimuser", | ||||||
|  |             name="scim_id", | ||||||
|  |             field=models.TextField(default="temp"), | ||||||
|  |             preserve_default=False, | ||||||
|  |         ), | ||||||
|  |         migrations.RunPython(fix_scim_user_group_pk), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="scimgroup", | ||||||
|  |             name="id", | ||||||
|  |             field=models.UUIDField( | ||||||
|  |                 default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="scimuser", | ||||||
|  |             name="id", | ||||||
|  |             field=models.UUIDField( | ||||||
|  |                 default=uuid.uuid4, editable=False, primary_key=True, serialize=False | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField(model_name="scimuser", name="scim_id", field=models.TextField()), | ||||||
|  |         migrations.AlterField(model_name="scimgroup", name="scim_id", field=models.TextField()), | ||||||
|  |         migrations.AlterUniqueTogether( | ||||||
|  |             name="scimgroup", | ||||||
|  |             unique_together={("scim_id", "group", "provider")}, | ||||||
|  |         ), | ||||||
|  |         migrations.AlterUniqueTogether( | ||||||
|  |             name="scimuser", | ||||||
|  |             unique_together={("scim_id", "user", "provider")}, | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -1,5 +1,7 @@ | |||||||
| """SCIM Provider models""" | """SCIM Provider models""" | ||||||
|  |  | ||||||
|  | from uuid import uuid4 | ||||||
|  |  | ||||||
| from django.core.cache import cache | from django.core.cache import cache | ||||||
| from django.db import models | from django.db import models | ||||||
| from django.db.models import QuerySet | from django.db.models import QuerySet | ||||||
| @ -97,26 +99,28 @@ class SCIMMapping(PropertyMapping): | |||||||
| class SCIMUser(models.Model): | class SCIMUser(models.Model): | ||||||
|     """Mapping of a user and provider to a SCIM user ID""" |     """Mapping of a user and provider to a SCIM user ID""" | ||||||
|  |  | ||||||
|     id = models.TextField(primary_key=True) |     id = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
|  |     scim_id = models.TextField() | ||||||
|     user = models.ForeignKey(User, on_delete=models.CASCADE) |     user = models.ForeignKey(User, on_delete=models.CASCADE) | ||||||
|     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) |     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("id", "user", "provider"),) |         unique_together = (("scim_id", "user", "provider"),) | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"SCIM User {self.user.username} to {self.provider.name}" |         return f"SCIM User {self.user_id} to {self.provider_id}" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMGroup(models.Model): | class SCIMGroup(models.Model): | ||||||
|     """Mapping of a group and provider to a SCIM user ID""" |     """Mapping of a group and provider to a SCIM user ID""" | ||||||
|  |  | ||||||
|     id = models.TextField(primary_key=True) |     id = models.UUIDField(primary_key=True, editable=False, default=uuid4) | ||||||
|  |     scim_id = models.TextField() | ||||||
|     group = models.ForeignKey(Group, on_delete=models.CASCADE) |     group = models.ForeignKey(Group, on_delete=models.CASCADE) | ||||||
|     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) |     provider = models.ForeignKey(SCIMProvider, on_delete=models.CASCADE) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("id", "group", "provider"),) |         unique_together = (("scim_id", "group", "provider"),) | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"SCIM Group {self.group.name} to {self.provider.name}" |         return f"SCIM Group {self.group_id} to {self.provider_id}" | ||||||
|  | |||||||
| @ -9,7 +9,7 @@ from structlog.stdlib import get_logger | |||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
| from authentik.lib.utils.reflection import class_to_path | from authentik.lib.utils.reflection import class_to_path | ||||||
| from authentik.providers.scim.models import SCIMProvider | from authentik.providers.scim.models import SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_sync | from authentik.providers.scim.tasks import scim_signal_direct, scim_signal_m2m, scim_task_wrapper | ||||||
|  |  | ||||||
| LOGGER = get_logger() | LOGGER = get_logger() | ||||||
|  |  | ||||||
| @ -17,7 +17,7 @@ LOGGER = get_logger() | |||||||
| @receiver(post_save, sender=SCIMProvider) | @receiver(post_save, sender=SCIMProvider) | ||||||
| def post_save_provider(sender: type[Model], instance, created: bool, **_): | def post_save_provider(sender: type[Model], instance, created: bool, **_): | ||||||
|     """Trigger sync when SCIM provider is saved""" |     """Trigger sync when SCIM provider is saved""" | ||||||
|     scim_sync.delay(instance.pk) |     scim_task_wrapper(instance.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(post_save, sender=User) | @receiver(post_save, sender=User) | ||||||
|  | |||||||
| @ -38,7 +38,23 @@ def client_for_model(provider: SCIMProvider, model: Model) -> SCIMClient: | |||||||
| def scim_sync_all(): | def scim_sync_all(): | ||||||
|     """Run sync for all providers""" |     """Run sync for all providers""" | ||||||
|     for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False): |     for provider in SCIMProvider.objects.filter(backchannel_application__isnull=False): | ||||||
|         scim_sync.delay(provider.pk) |         scim_task_wrapper(provider.pk) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def scim_task_wrapper(provider_pk: int): | ||||||
|  |     """Wrap scim_sync to set the correct timeouts""" | ||||||
|  |     provider: SCIMProvider = SCIMProvider.objects.filter( | ||||||
|  |         pk=provider_pk, backchannel_application__isnull=False | ||||||
|  |     ).first() | ||||||
|  |     if not provider: | ||||||
|  |         return | ||||||
|  |     users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE) | ||||||
|  |     groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE) | ||||||
|  |     soft_time_limit = (users_paginator.num_pages + groups_paginator.num_pages) * PAGE_TIMEOUT | ||||||
|  |     time_limit = soft_time_limit * 1.5 | ||||||
|  |     return scim_sync.apply_async( | ||||||
|  |         (provider.pk,), time_limit=int(time_limit), soft_time_limit=int(soft_time_limit) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @CELERY_APP.task(bind=True, base=SystemTask) | @CELERY_APP.task(bind=True, base=SystemTask) | ||||||
| @ -60,7 +76,7 @@ def scim_sync(self: SystemTask, provider_pk: int) -> None: | |||||||
|     users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE) |     users_paginator = Paginator(provider.get_user_qs(), PAGE_SIZE) | ||||||
|     groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE) |     groups_paginator = Paginator(provider.get_group_qs(), PAGE_SIZE) | ||||||
|     self.soft_time_limit = self.time_limit = ( |     self.soft_time_limit = self.time_limit = ( | ||||||
|         users_paginator.count + groups_paginator.count |         users_paginator.num_pages + groups_paginator.num_pages | ||||||
|     ) * PAGE_TIMEOUT |     ) * PAGE_TIMEOUT | ||||||
|     with allow_join_result(): |     with allow_join_result(): | ||||||
|         try: |         try: | ||||||
|  | |||||||
| @ -8,7 +8,7 @@ from authentik.core.models import Application, Group, User | |||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | from authentik.providers.scim.clients.schema import ServiceProviderConfiguration | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync | from authentik.providers.scim.tasks import scim_task_wrapper | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -79,7 +79,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             scim_sync.delay(self.provider.pk).get() |             scim_task_wrapper(self.provider.pk).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 6) |             self.assertEqual(mocker.call_count, 6) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
| @ -169,7 +169,7 @@ class SCIMMembershipTests(TestCase): | |||||||
|             ) |             ) | ||||||
|  |  | ||||||
|             self.configure() |             self.configure() | ||||||
|             scim_sync.delay(self.provider.pk).get() |             scim_task_wrapper(self.provider.pk).get() | ||||||
|  |  | ||||||
|             self.assertEqual(mocker.call_count, 6) |             self.assertEqual(mocker.call_count, 6) | ||||||
|             self.assertEqual(mocker.request_history[0].method, "GET") |             self.assertEqual(mocker.request_history[0].method, "GET") | ||||||
|  | |||||||
| @ -10,7 +10,7 @@ from authentik.blueprints.tests import apply_blueprint | |||||||
| from authentik.core.models import Application, Group, User | from authentik.core.models import Application, Group, User | ||||||
| from authentik.lib.generators import generate_id | from authentik.lib.generators import generate_id | ||||||
| from authentik.providers.scim.models import SCIMMapping, SCIMProvider | from authentik.providers.scim.models import SCIMMapping, SCIMProvider | ||||||
| from authentik.providers.scim.tasks import scim_sync | from authentik.providers.scim.tasks import scim_task_wrapper | ||||||
| from authentik.tenants.models import Tenant | from authentik.tenants.models import Tenant | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -88,6 +88,72 @@ class SCIMUserTests(TestCase): | |||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     @Mocker() | ||||||
|  |     def test_user_create_different_provider_same_id(self, mock: Mocker): | ||||||
|  |         """Test user creation with multiple providers that happen | ||||||
|  |         to return the same object ID""" | ||||||
|  |         # Create duplicate provider | ||||||
|  |         provider: SCIMProvider = SCIMProvider.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             url="https://localhost", | ||||||
|  |             token=generate_id(), | ||||||
|  |             exclude_users_service_account=True, | ||||||
|  |         ) | ||||||
|  |         app: Application = Application.objects.create( | ||||||
|  |             name=generate_id(), | ||||||
|  |             slug=generate_id(), | ||||||
|  |         ) | ||||||
|  |         app.backchannel_providers.add(provider) | ||||||
|  |         provider.property_mappings.add( | ||||||
|  |             SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user") | ||||||
|  |         ) | ||||||
|  |         provider.property_mappings_group.add( | ||||||
|  |             SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group") | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         scim_id = generate_id() | ||||||
|  |         mock.get( | ||||||
|  |             "https://localhost/ServiceProviderConfig", | ||||||
|  |             json={}, | ||||||
|  |         ) | ||||||
|  |         mock.post( | ||||||
|  |             "https://localhost/Users", | ||||||
|  |             json={ | ||||||
|  |                 "id": scim_id, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         uid = generate_id() | ||||||
|  |         user = User.objects.create( | ||||||
|  |             username=uid, | ||||||
|  |             name=f"{uid} {uid}", | ||||||
|  |             email=f"{uid}@goauthentik.io", | ||||||
|  |         ) | ||||||
|  |         self.assertEqual(mock.call_count, 4) | ||||||
|  |         self.assertEqual(mock.request_history[0].method, "GET") | ||||||
|  |         self.assertEqual(mock.request_history[1].method, "POST") | ||||||
|  |         self.assertJSONEqual( | ||||||
|  |             mock.request_history[1].body, | ||||||
|  |             { | ||||||
|  |                 "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], | ||||||
|  |                 "active": True, | ||||||
|  |                 "emails": [ | ||||||
|  |                     { | ||||||
|  |                         "primary": True, | ||||||
|  |                         "type": "other", | ||||||
|  |                         "value": f"{uid}@goauthentik.io", | ||||||
|  |                     } | ||||||
|  |                 ], | ||||||
|  |                 "externalId": user.uid, | ||||||
|  |                 "name": { | ||||||
|  |                     "familyName": uid, | ||||||
|  |                     "formatted": f"{uid} {uid}", | ||||||
|  |                     "givenName": uid, | ||||||
|  |                 }, | ||||||
|  |                 "displayName": f"{uid} {uid}", | ||||||
|  |                 "userName": uid, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     @Mocker() |     @Mocker() | ||||||
|     def test_user_create_update(self, mock: Mocker): |     def test_user_create_update(self, mock: Mocker): | ||||||
|         """Test user creation and update""" |         """Test user creation and update""" | ||||||
| @ -236,7 +302,7 @@ class SCIMUserTests(TestCase): | |||||||
|             email=f"{uid}@goauthentik.io", |             email=f"{uid}@goauthentik.io", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         scim_sync.delay(self.provider.pk).get() |         scim_task_wrapper(self.provider.pk).get() | ||||||
|  |  | ||||||
|         self.assertEqual(mock.call_count, 5) |         self.assertEqual(mock.call_count, 5) | ||||||
|         self.assertEqual(mock.request_history[0].method, "GET") |         self.assertEqual(mock.request_history[0].method, "GET") | ||||||
|  | |||||||
| @ -25,7 +25,7 @@ class ObjectFilter(ObjectPermissionsFilter): | |||||||
|         # Outposts (which are the only objects using internal service accounts) |         # Outposts (which are the only objects using internal service accounts) | ||||||
|         # except requests to return an empty list when they have no objects |         # except requests to return an empty list when they have no objects | ||||||
|         # assigned |         # assigned | ||||||
|         if request.user.type == UserTypes.INTERNAL_SERVICE_ACCOUNT: |         if getattr(request.user, "type", UserTypes.INTERNAL) == UserTypes.INTERNAL_SERVICE_ACCOUNT: | ||||||
|             return queryset |             return queryset | ||||||
|         if not queryset.exists(): |         if not queryset.exists(): | ||||||
|             # User doesn't have direct permission to all objects |             # User doesn't have direct permission to all objects | ||||||
|  | |||||||
| @ -376,7 +376,13 @@ CELERY = { | |||||||
|     "task_default_queue": "authentik", |     "task_default_queue": "authentik", | ||||||
|     "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")), |     "broker_url": CONFIG.get("broker.url") or redis_url(CONFIG.get("redis.db")), | ||||||
|     "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")), |     "result_backend": CONFIG.get("result_backend.url") or redis_url(CONFIG.get("redis.db")), | ||||||
|     "broker_transport_options": CONFIG.get_dict_from_b64_json("broker.transport_options"), |     "broker_transport_options": CONFIG.get_dict_from_b64_json( | ||||||
|  |         "broker.transport_options", {"retry_policy": {"timeout": 5.0}} | ||||||
|  |     ), | ||||||
|  |     "result_backend_transport_options": CONFIG.get_dict_from_b64_json( | ||||||
|  |         "result_backend.transport_options", {"retry_policy": {"timeout": 5.0}} | ||||||
|  |     ), | ||||||
|  |     "redis_retry_on_timeout": True, | ||||||
| } | } | ||||||
|  |  | ||||||
| # Sentry integration | # Sentry integration | ||||||
|  | |||||||
| @ -80,7 +80,7 @@ class OAuth2Client(BaseOAuthClient): | |||||||
|             access_token_url = self.source.source_type.access_token_url or "" |             access_token_url = self.source.source_type.access_token_url or "" | ||||||
|             if self.source.source_type.urls_customizable and self.source.access_token_url: |             if self.source.source_type.urls_customizable and self.source.access_token_url: | ||||||
|                 access_token_url = self.source.access_token_url |                 access_token_url = self.source.access_token_url | ||||||
|             response = self.session.request( |             response = self.do_request( | ||||||
|                 "post", access_token_url, data=args, headers=self._default_headers, **request_kwargs |                 "post", access_token_url, data=args, headers=self._default_headers, **request_kwargs | ||||||
|             ) |             ) | ||||||
|             response.raise_for_status() |             response.raise_for_status() | ||||||
|  | |||||||
							
								
								
									
										37
									
								
								authentik/sources/oauth/tests/test_type_apple.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								authentik/sources/oauth/tests/test_type_apple.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,37 @@ | |||||||
|  | """Apple Type tests""" | ||||||
|  |  | ||||||
|  | from django.test import RequestFactory, TestCase | ||||||
|  | from guardian.shortcuts import get_anonymous_user | ||||||
|  |  | ||||||
|  | from authentik.lib.generators import generate_id | ||||||
|  | from authentik.lib.tests.utils import dummy_get_response | ||||||
|  | from authentik.root.middleware import SessionMiddleware | ||||||
|  | from authentik.sources.oauth.models import OAuthSource | ||||||
|  | from authentik.sources.oauth.types.registry import registry | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestTypeApple(TestCase): | ||||||
|  |     """OAuth Source tests""" | ||||||
|  |  | ||||||
|  |     def setUp(self): | ||||||
|  |         self.source = OAuthSource.objects.create( | ||||||
|  |             name="test", | ||||||
|  |             slug="test", | ||||||
|  |             provider_type="apple", | ||||||
|  |             authorization_url="", | ||||||
|  |             profile_url="", | ||||||
|  |             consumer_key=generate_id(), | ||||||
|  |         ) | ||||||
|  |         self.factory = RequestFactory() | ||||||
|  |  | ||||||
|  |     def test_login_challenge(self): | ||||||
|  |         """Test login_challenge""" | ||||||
|  |         request = self.factory.get("/") | ||||||
|  |         request.user = get_anonymous_user() | ||||||
|  |  | ||||||
|  |         middleware = SessionMiddleware(dummy_get_response) | ||||||
|  |         middleware.process_request(request) | ||||||
|  |         request.session.save() | ||||||
|  |         oauth_type = registry.find_type("apple") | ||||||
|  |         challenge = oauth_type().login_challenge(self.source, request) | ||||||
|  |         self.assertTrue(challenge.is_valid(raise_exception=True)) | ||||||
| @ -125,7 +125,7 @@ class AppleType(SourceType): | |||||||
|         ) |         ) | ||||||
|         args = apple_client.get_redirect_args() |         args = apple_client.get_redirect_args() | ||||||
|         return AppleLoginChallenge( |         return AppleLoginChallenge( | ||||||
|             instance={ |             data={ | ||||||
|                 "client_id": apple_client.get_client_id(), |                 "client_id": apple_client.get_client_id(), | ||||||
|                 "scope": "name email", |                 "scope": "name email", | ||||||
|                 "redirect_uri": args["redirect_uri"], |                 "redirect_uri": args["redirect_uri"], | ||||||
|  | |||||||
| @ -66,7 +66,7 @@ class PlexSource(Source): | |||||||
|             icon = static("authentik/sources/plex.svg") |             icon = static("authentik/sources/plex.svg") | ||||||
|         return UILoginButton( |         return UILoginButton( | ||||||
|             challenge=PlexAuthenticationChallenge( |             challenge=PlexAuthenticationChallenge( | ||||||
|                 { |                 data={ | ||||||
|                     "type": ChallengeTypes.NATIVE.value, |                     "type": ChallengeTypes.NATIVE.value, | ||||||
|                     "component": "ak-source-plex", |                     "component": "ak-source-plex", | ||||||
|                     "client_id": self.client_id, |                     "client_id": self.client_id, | ||||||
|  | |||||||
| @ -40,6 +40,11 @@ class TestPlexSource(TestCase): | |||||||
|             slug="test", |             slug="test", | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_login_challenge(self): | ||||||
|  |         """Test login_challenge""" | ||||||
|  |         ui_login_button = self.source.ui_login_button(None) | ||||||
|  |         self.assertTrue(ui_login_button.challenge.is_valid(raise_exception=True)) | ||||||
|  |  | ||||||
|     def test_get_user_info(self): |     def test_get_user_info(self): | ||||||
|         """Test get_user_info""" |         """Test get_user_info""" | ||||||
|         token = generate_key() |         token = generate_key() | ||||||
|  | |||||||
| @ -0,0 +1,44 @@ | |||||||
|  | # Generated by Django 5.0.4 on 2024-05-01 15:44 | ||||||
|  |  | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_sources_saml", "0013_samlsource_verification_kp_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="samlsource", | ||||||
|  |             name="digest_algorithm", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 choices=[ | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#sha1", "SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmlenc#sha256", "SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#sha384", "SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmlenc#sha512", "SHA512"), | ||||||
|  |                 ], | ||||||
|  |                 default="http://www.w3.org/2001/04/xmlenc#sha256", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="samlsource", | ||||||
|  |             name="signature_algorithm", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 choices=[ | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#rsa-sha1", "RSA-SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", "RSA-SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", "RSA-SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", "RSA-SHA512"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", "ECDSA-SHA1"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", "ECDSA-SHA256"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", "ECDSA-SHA384"), | ||||||
|  |                     ("http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", "ECDSA-SHA512"), | ||||||
|  |                     ("http://www.w3.org/2000/09/xmldsig#dsa-sha1", "DSA-SHA1"), | ||||||
|  |                 ], | ||||||
|  |                 default="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -15,6 +15,10 @@ from authentik.flows.models import Flow | |||||||
| from authentik.lib.utils.time import timedelta_string_validator | from authentik.lib.utils.time import timedelta_string_validator | ||||||
| from authentik.sources.saml.processors.constants import ( | from authentik.sources.saml.processors.constants import ( | ||||||
|     DSA_SHA1, |     DSA_SHA1, | ||||||
|  |     ECDSA_SHA1, | ||||||
|  |     ECDSA_SHA256, | ||||||
|  |     ECDSA_SHA384, | ||||||
|  |     ECDSA_SHA512, | ||||||
|     RSA_SHA1, |     RSA_SHA1, | ||||||
|     RSA_SHA256, |     RSA_SHA256, | ||||||
|     RSA_SHA384, |     RSA_SHA384, | ||||||
| @ -143,8 +147,7 @@ class SAMLSource(Source): | |||||||
|         verbose_name=_("Signing Keypair"), |         verbose_name=_("Signing Keypair"), | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     digest_algorithm = models.CharField( |     digest_algorithm = models.TextField( | ||||||
|         max_length=50, |  | ||||||
|         choices=( |         choices=( | ||||||
|             (SHA1, _("SHA1")), |             (SHA1, _("SHA1")), | ||||||
|             (SHA256, _("SHA256")), |             (SHA256, _("SHA256")), | ||||||
| @ -153,13 +156,16 @@ class SAMLSource(Source): | |||||||
|         ), |         ), | ||||||
|         default=SHA256, |         default=SHA256, | ||||||
|     ) |     ) | ||||||
|     signature_algorithm = models.CharField( |     signature_algorithm = models.TextField( | ||||||
|         max_length=50, |  | ||||||
|         choices=( |         choices=( | ||||||
|             (RSA_SHA1, _("RSA-SHA1")), |             (RSA_SHA1, _("RSA-SHA1")), | ||||||
|             (RSA_SHA256, _("RSA-SHA256")), |             (RSA_SHA256, _("RSA-SHA256")), | ||||||
|             (RSA_SHA384, _("RSA-SHA384")), |             (RSA_SHA384, _("RSA-SHA384")), | ||||||
|             (RSA_SHA512, _("RSA-SHA512")), |             (RSA_SHA512, _("RSA-SHA512")), | ||||||
|  |             (ECDSA_SHA1, _("ECDSA-SHA1")), | ||||||
|  |             (ECDSA_SHA256, _("ECDSA-SHA256")), | ||||||
|  |             (ECDSA_SHA384, _("ECDSA-SHA384")), | ||||||
|  |             (ECDSA_SHA512, _("ECDSA-SHA512")), | ||||||
|             (DSA_SHA1, _("DSA-SHA1")), |             (DSA_SHA1, _("DSA-SHA1")), | ||||||
|         ), |         ), | ||||||
|         default=RSA_SHA256, |         default=RSA_SHA256, | ||||||
|  | |||||||
| @ -26,9 +26,16 @@ SAML_BINDING_REDIRECT = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" | |||||||
|  |  | ||||||
| DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1" | DSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#dsa-sha1" | ||||||
| RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1" | RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1" | ||||||
|  | # https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.2 | ||||||
| RSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" | RSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" | ||||||
| RSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384" | RSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384" | ||||||
| RSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512" | RSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512" | ||||||
|  | # https://datatracker.ietf.org/doc/html/rfc4051#section-2.3.6 | ||||||
|  | ECDSA_SHA1 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1" | ||||||
|  | ECDSA_SHA224 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha224" | ||||||
|  | ECDSA_SHA256 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" | ||||||
|  | ECDSA_SHA384 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384" | ||||||
|  | ECDSA_SHA512 = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512" | ||||||
|  |  | ||||||
| SHA1 = "http://www.w3.org/2000/09/xmldsig#sha1" | SHA1 = "http://www.w3.org/2000/09/xmldsig#sha1" | ||||||
| SHA256 = "http://www.w3.org/2001/04/xmlenc#sha256" | SHA256 = "http://www.w3.org/2001/04/xmlenc#sha256" | ||||||
| @ -41,6 +48,11 @@ SIGN_ALGORITHM_TRANSFORM_MAP = { | |||||||
|     RSA_SHA256: xmlsec.constants.TransformRsaSha256, |     RSA_SHA256: xmlsec.constants.TransformRsaSha256, | ||||||
|     RSA_SHA384: xmlsec.constants.TransformRsaSha384, |     RSA_SHA384: xmlsec.constants.TransformRsaSha384, | ||||||
|     RSA_SHA512: xmlsec.constants.TransformRsaSha512, |     RSA_SHA512: xmlsec.constants.TransformRsaSha512, | ||||||
|  |     ECDSA_SHA1: xmlsec.constants.TransformEcdsaSha1, | ||||||
|  |     ECDSA_SHA224: xmlsec.constants.TransformEcdsaSha224, | ||||||
|  |     ECDSA_SHA256: xmlsec.constants.TransformEcdsaSha256, | ||||||
|  |     ECDSA_SHA384: xmlsec.constants.TransformEcdsaSha384, | ||||||
|  |     ECDSA_SHA512: xmlsec.constants.TransformEcdsaSha512, | ||||||
| } | } | ||||||
|  |  | ||||||
| DIGEST_ALGORITHM_TRANSLATION_MAP = { | DIGEST_ALGORITHM_TRANSLATION_MAP = { | ||||||
|  | |||||||
| @ -60,7 +60,7 @@ class SCIMSourceUser(SerializerModel): | |||||||
|         unique_together = (("id", "user", "source"),) |         unique_together = (("id", "user", "source"),) | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"SCIM User {self.user.username} to {self.source.name}" |         return f"SCIM User {self.user_id} to {self.source_id}" | ||||||
|  |  | ||||||
|  |  | ||||||
| class SCIMSourceGroup(SerializerModel): | class SCIMSourceGroup(SerializerModel): | ||||||
| @ -81,4 +81,4 @@ class SCIMSourceGroup(SerializerModel): | |||||||
|         unique_together = (("id", "group", "source"),) |         unique_together = (("id", "group", "source"),) | ||||||
|  |  | ||||||
|     def __str__(self) -> str: |     def __str__(self) -> str: | ||||||
|         return f"SCIM Group {self.group.name} to {self.source.name}" |         return f"SCIM Group {self.group_id} to {self.source_id}" | ||||||
|  | |||||||
| @ -2,9 +2,11 @@ from django.db.models import Model | |||||||
| from django.db.models.signals import pre_delete, pre_save | from django.db.models.signals import pre_delete, pre_save | ||||||
| from django.dispatch import receiver | from django.dispatch import receiver | ||||||
|  |  | ||||||
| from authentik.core.models import Token, TokenIntents, User, UserTypes | from authentik.core.models import USER_PATH_SYSTEM_PREFIX, Token, TokenIntents, User, UserTypes | ||||||
| from authentik.sources.scim.models import SCIMSource | from authentik.sources.scim.models import SCIMSource | ||||||
|  |  | ||||||
|  | USER_PATH_SOURCE_SCIM = USER_PATH_SYSTEM_PREFIX + "/sources/scim" | ||||||
|  |  | ||||||
|  |  | ||||||
| @receiver(pre_save, sender=SCIMSource) | @receiver(pre_save, sender=SCIMSource) | ||||||
| def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_): | def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_): | ||||||
| @ -16,6 +18,7 @@ def scim_source_pre_save(sender: type[Model], instance: SCIMSource, **_): | |||||||
|         username=identifier, |         username=identifier, | ||||||
|         name=f"SCIM Source {instance.name} Service-Account", |         name=f"SCIM Source {instance.name} Service-Account", | ||||||
|         type=UserTypes.INTERNAL_SERVICE_ACCOUNT, |         type=UserTypes.INTERNAL_SERVICE_ACCOUNT, | ||||||
|  |         path=USER_PATH_SOURCE_SCIM, | ||||||
|     ) |     ) | ||||||
|     token = Token.objects.create( |     token = Token.objects.create( | ||||||
|         user=user, |         user=user, | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
|  |  | ||||||
| from authentik.core.models import Group, User | from authentik.core.models import Group, User | ||||||
|  | from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||||
| from authentik.providers.scim.clients.schema import Group as SCIMGroupModel | from authentik.providers.scim.clients.schema import Group as SCIMGroupModel | ||||||
| from authentik.sources.scim.models import SCIMSourceGroup | from authentik.sources.scim.models import SCIMSourceGroup | ||||||
| from authentik.sources.scim.views.v2.base import SCIMView | from authentik.sources.scim.views.v2.base import SCIMView | ||||||
| @ -26,9 +27,11 @@ class GroupsView(SCIMView): | |||||||
|     def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: |     def group_to_scim(self, scim_group: SCIMSourceGroup) -> dict: | ||||||
|         """Convert Group to SCIM data""" |         """Convert Group to SCIM data""" | ||||||
|         payload = SCIMGroupModel( |         payload = SCIMGroupModel( | ||||||
|  |             schemas=[SCIM_USER_SCHEMA], | ||||||
|             id=str(scim_group.group.pk), |             id=str(scim_group.group.pk), | ||||||
|             externalId=scim_group.id, |             externalId=scim_group.id, | ||||||
|             displayName=scim_group.group.name, |             displayName=scim_group.group.name, | ||||||
|  |             members=[], | ||||||
|             meta={ |             meta={ | ||||||
|                 "resourceType": "Group", |                 "resourceType": "Group", | ||||||
|                 "location": self.request.build_absolute_uri( |                 "location": self.request.build_absolute_uri( | ||||||
| @ -42,28 +45,24 @@ class GroupsView(SCIMView): | |||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         return payload.model_dump( |         for member in scim_group.group.users.order_by("pk"): | ||||||
|             mode="json", |             member: User | ||||||
|             exclude_unset=True, |             payload.members.append(GroupMember(value=str(member.uuid))) | ||||||
|         ) |         return payload.model_dump(mode="json", exclude_unset=True) | ||||||
|  |  | ||||||
|     def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response: |     def get(self, request: Request, group_id: str | None = None, **kwargs) -> Response: | ||||||
|         """List Group handler""" |         """List Group handler""" | ||||||
|         if group_id: |         base_query = SCIMSourceGroup.objects.select_related("group").prefetch_related( | ||||||
|             connection = ( |             "group__users" | ||||||
|                 SCIMSourceGroup.objects.filter(source=self.source, group__group_uuid=group_id) |  | ||||||
|                 .select_related("group") |  | ||||||
|                 .first() |  | ||||||
|         ) |         ) | ||||||
|  |         if group_id: | ||||||
|  |             connection = base_query.filter(source=self.source, group__group_uuid=group_id).first() | ||||||
|             if not connection: |             if not connection: | ||||||
|                 raise Http404 |                 raise Http404 | ||||||
|             return Response(self.group_to_scim(connection)) |             return Response(self.group_to_scim(connection)) | ||||||
|         connections = ( |         connections = ( | ||||||
|             SCIMSourceGroup.objects.filter(source=self.source) |             base_query.filter(source=self.source).order_by("pk").filter(self.filter_parse(request)) | ||||||
|             .select_related("group") |  | ||||||
|             .order_by("pk") |  | ||||||
|         ) |         ) | ||||||
|         connections = connections.filter(self.filter_parse(request)) |  | ||||||
|         page = self.paginate_query(connections) |         page = self.paginate_query(connections) | ||||||
|         return Response( |         return Response( | ||||||
|             { |             { | ||||||
| @ -79,6 +78,8 @@ class GroupsView(SCIMView): | |||||||
|     def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict): |     def update_group(self, connection: SCIMSourceGroup | None, data: QueryDict): | ||||||
|         """Partial update a group""" |         """Partial update a group""" | ||||||
|         group = connection.group if connection else Group() |         group = connection.group if connection else Group() | ||||||
|  |         if _group := Group.objects.filter(name=data.get("displayName")).first(): | ||||||
|  |             group = _group | ||||||
|         if "displayName" in data: |         if "displayName" in data: | ||||||
|             group.name = data.get("displayName") |             group.name = data.get("displayName") | ||||||
|         if group.name == "": |         if group.name == "": | ||||||
|  | |||||||
| @ -11,6 +11,7 @@ from rest_framework.request import Request | |||||||
| from rest_framework.response import Response | from rest_framework.response import Response | ||||||
|  |  | ||||||
| from authentik.core.models import User | from authentik.core.models import User | ||||||
|  | from authentik.providers.scim.clients.schema import SCIM_USER_SCHEMA | ||||||
| from authentik.providers.scim.clients.schema import User as SCIMUserModel | from authentik.providers.scim.clients.schema import User as SCIMUserModel | ||||||
| from authentik.sources.scim.models import SCIMSourceUser | from authentik.sources.scim.models import SCIMSourceUser | ||||||
| from authentik.sources.scim.views.v2.base import SCIMView | from authentik.sources.scim.views.v2.base import SCIMView | ||||||
| @ -33,6 +34,7 @@ class UsersView(SCIMView): | |||||||
|     def user_to_scim(self, scim_user: SCIMSourceUser) -> dict: |     def user_to_scim(self, scim_user: SCIMSourceUser) -> dict: | ||||||
|         """Convert User to SCIM data""" |         """Convert User to SCIM data""" | ||||||
|         payload = SCIMUserModel( |         payload = SCIMUserModel( | ||||||
|  |             schemas=[SCIM_USER_SCHEMA], | ||||||
|             id=str(scim_user.user.uuid), |             id=str(scim_user.user.uuid), | ||||||
|             externalId=scim_user.id, |             externalId=scim_user.id, | ||||||
|             userName=scim_user.user.username, |             userName=scim_user.user.username, | ||||||
| @ -62,10 +64,7 @@ class UsersView(SCIMView): | |||||||
|                 ), |                 ), | ||||||
|             }, |             }, | ||||||
|         ) |         ) | ||||||
|         final_payload = payload.model_dump( |         final_payload = payload.model_dump(mode="json", exclude_unset=True) | ||||||
|             mode="json", |  | ||||||
|             exclude_unset=True, |  | ||||||
|         ) |  | ||||||
|         final_payload.update(scim_user.attributes) |         final_payload.update(scim_user.attributes) | ||||||
|         return final_payload |         return final_payload | ||||||
|  |  | ||||||
| @ -99,6 +98,8 @@ class UsersView(SCIMView): | |||||||
|     def update_user(self, connection: SCIMSourceUser | None, data: QueryDict): |     def update_user(self, connection: SCIMSourceUser | None, data: QueryDict): | ||||||
|         """Partial update a user""" |         """Partial update a user""" | ||||||
|         user = connection.user if connection else User() |         user = connection.user if connection else User() | ||||||
|  |         if _user := User.objects.filter(username=data.get("userName")).first(): | ||||||
|  |             user = _user | ||||||
|         user.path = self.source.get_user_path() |         user.path = self.source.get_user_path() | ||||||
|         if "userName" in data: |         if "userName" in data: | ||||||
|             user.username = data.get("userName") |             user.username = data.get("userName") | ||||||
|  | |||||||
| @ -96,7 +96,7 @@ class DuoDevice(SerializerModel, Device): | |||||||
|         return DuoDeviceSerializer |         return DuoDeviceSerializer | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return str(self.name) or str(self.user) |         return str(self.name) or str(self.user_id) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Duo Device") |         verbose_name = _("Duo Device") | ||||||
|  | |||||||
| @ -221,7 +221,7 @@ class SMSDevice(SerializerModel, SideChannelDevice): | |||||||
|         return valid |         return valid | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return str(self.name) or str(self.user) |         return str(self.name) or str(self.user_id) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("SMS Device") |         verbose_name = _("SMS Device") | ||||||
|  | |||||||
| @ -155,7 +155,7 @@ class WebAuthnDevice(SerializerModel, Device): | |||||||
|         return WebAuthnDeviceSerializer |         return WebAuthnDeviceSerializer | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return str(self.name) or str(self.user) |         return str(self.name) or str(self.user_id) | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("WebAuthn Device") |         verbose_name = _("WebAuthn Device") | ||||||
|  | |||||||
| @ -65,7 +65,7 @@ class UserConsent(SerializerModel, ExpiringModel): | |||||||
|         return UserConsentSerializer |         return UserConsentSerializer | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"User Consent {self.application} by {self.user}" |         return f"User Consent {self.application_id} by {self.user_id}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         unique_together = (("user", "application", "permissions"),) |         unique_together = (("user", "application", "permissions"),) | ||||||
|  | |||||||
| @ -79,7 +79,7 @@ class Invitation(SerializerModel, ExpiringModel): | |||||||
|         return InvitationSerializer |         return InvitationSerializer | ||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return f"Invitation {str(self.invite_uuid)} created by {self.created_by}" |         return f"Invitation {str(self.invite_uuid)} created by {self.created_by_id}" | ||||||
|  |  | ||||||
|     class Meta: |     class Meta: | ||||||
|         verbose_name = _("Invitation") |         verbose_name = _("Invitation") | ||||||
|  | |||||||
| @ -0,0 +1,23 @@ | |||||||
|  | # Generated by Django 5.0.4 on 2024-05-01 15:32 | ||||||
|  |  | ||||||
|  | import authentik.lib.utils.time | ||||||
|  | from django.db import migrations, models | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Migration(migrations.Migration): | ||||||
|  |  | ||||||
|  |     dependencies = [ | ||||||
|  |         ("authentik_tenants", "0002_tenant_default_token_duration_and_more"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     operations = [ | ||||||
|  |         migrations.AlterField( | ||||||
|  |             model_name="tenant", | ||||||
|  |             name="default_token_duration", | ||||||
|  |             field=models.TextField( | ||||||
|  |                 default="days=1", | ||||||
|  |                 help_text="Default token duration", | ||||||
|  |                 validators=[authentik.lib.utils.time.timedelta_string_validator], | ||||||
|  |             ), | ||||||
|  |         ), | ||||||
|  |     ] | ||||||
| @ -23,7 +23,7 @@ LOGGER = get_logger() | |||||||
|  |  | ||||||
| VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$") | VALID_SCHEMA_NAME = re.compile(r"^t_[a-z0-9]{1,61}$") | ||||||
|  |  | ||||||
| DEFAULT_TOKEN_DURATION = "minutes=30"  # nosec | DEFAULT_TOKEN_DURATION = "days=1"  # nosec | ||||||
| DEFAULT_TOKEN_LENGTH = 60 | DEFAULT_TOKEN_LENGTH = 60 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ | |||||||
| from tenant_schemas_celery.scheduler import ( | from tenant_schemas_celery.scheduler import ( | ||||||
|     TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler, |     TenantAwarePersistentScheduler as BaseTenantAwarePersistentScheduler, | ||||||
| ) | ) | ||||||
|  | from tenant_schemas_celery.scheduler import TenantAwareScheduleEntry | ||||||
|  |  | ||||||
|  |  | ||||||
| class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): | class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): | ||||||
| @ -11,3 +12,11 @@ class TenantAwarePersistentScheduler(BaseTenantAwarePersistentScheduler): | |||||||
|     @classmethod |     @classmethod | ||||||
|     def get_queryset(cls): |     def get_queryset(cls): | ||||||
|         return super().get_queryset().filter(ready=True) |         return super().get_queryset().filter(ready=True) | ||||||
|  |  | ||||||
|  |     def apply_entry(self, entry: TenantAwareScheduleEntry, producer=None): | ||||||
|  |         # https://github.com/maciej-gol/tenant-schemas-celery/blob/master/tenant_schemas_celery/scheduler.py#L85 | ||||||
|  |         # When (as by default) no tenant schemas are set, the public schema is excluded | ||||||
|  |         # so we need to explicitly include it here, otherwise the task is not executed | ||||||
|  |         if entry.tenant_schemas is None: | ||||||
|  |             entry.tenant_schemas = self.get_queryset().values_list("schema_name", flat=True) | ||||||
|  |         return super().apply_entry(entry, producer) | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ | |||||||
|     "$schema": "http://json-schema.org/draft-07/schema", |     "$schema": "http://json-schema.org/draft-07/schema", | ||||||
|     "$id": "https://goauthentik.io/blueprints/schema.json", |     "$id": "https://goauthentik.io/blueprints/schema.json", | ||||||
|     "type": "object", |     "type": "object", | ||||||
|     "title": "authentik 2024.2.3 Blueprint schema", |     "title": "authentik 2024.4.4 Blueprint schema", | ||||||
|     "required": [ |     "required": [ | ||||||
|         "version", |         "version", | ||||||
|         "entries" |         "entries" | ||||||
| @ -4131,6 +4131,10 @@ | |||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", | ||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", | ||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", | ||||||
|                         "http://www.w3.org/2000/09/xmldsig#dsa-sha1" |                         "http://www.w3.org/2000/09/xmldsig#dsa-sha1" | ||||||
|                     ], |                     ], | ||||||
|                     "title": "Signature algorithm" |                     "title": "Signature algorithm" | ||||||
| @ -4935,6 +4939,10 @@ | |||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", | ||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha384", | ||||||
|                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", |                         "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384", | ||||||
|  |                         "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512", | ||||||
|                         "http://www.w3.org/2000/09/xmldsig#dsa-sha1" |                         "http://www.w3.org/2000/09/xmldsig#dsa-sha1" | ||||||
|                     ], |                     ], | ||||||
|                     "title": "Signature algorithm" |                     "title": "Signature algorithm" | ||||||
|  | |||||||
| @ -32,7 +32,7 @@ services: | |||||||
|     volumes: |     volumes: | ||||||
|       - redis:/data |       - redis:/data | ||||||
|   server: |   server: | ||||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.3} |     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.4} | ||||||
|     restart: unless-stopped |     restart: unless-stopped | ||||||
|     command: server |     command: server | ||||||
|     environment: |     environment: | ||||||
| @ -53,7 +53,7 @@ services: | |||||||
|       - postgresql |       - postgresql | ||||||
|       - redis |       - redis | ||||||
|   worker: |   worker: | ||||||
|     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.2.3} |     image: ${AUTHENTIK_IMAGE:-ghcr.io/goauthentik/server}:${AUTHENTIK_TAG:-2024.4.4} | ||||||
|     restart: unless-stopped |     restart: unless-stopped | ||||||
|     command: worker |     command: worker | ||||||
|     environment: |     environment: | ||||||
|  | |||||||
| @ -29,4 +29,4 @@ func UserAgent() string { | |||||||
| 	return fmt.Sprintf("authentik@%s", FullVersion()) | 	return fmt.Sprintf("authentik@%s", FullVersion()) | ||||||
| } | } | ||||||
|  |  | ||||||
| const VERSION = "2024.2.3" | const VERSION = "2024.4.4" | ||||||
|  | |||||||
| @ -54,7 +54,7 @@ function cleanup { | |||||||
| } | } | ||||||
|  |  | ||||||
| function prepare_debug { | function prepare_debug { | ||||||
|     poetry install --no-ansi --no-interaction |     VIRTUAL_ENV=/ak-root/venv poetry install --no-ansi --no-interaction | ||||||
|     touch /unittest.xml |     touch /unittest.xml | ||||||
|     chown authentik:authentik /unittest.xml |     chown authentik:authentik /unittest.xml | ||||||
| } | } | ||||||
|  | |||||||
| @ -117,6 +117,8 @@ def run_migrations(): | |||||||
|         ) |         ) | ||||||
|     finally: |     finally: | ||||||
|         release_lock(curr) |         release_lock(curr) | ||||||
|  |         curr.close() | ||||||
|  |         conn.close() | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ | |||||||
| import authentik. This is done by the dockerfile.""" | import authentik. This is done by the dockerfile.""" | ||||||
| from sys import exit as sysexit | from sys import exit as sysexit | ||||||
| from time import sleep | from time import sleep | ||||||
| from urllib.parse import quote_plus |  | ||||||
|  |  | ||||||
| from psycopg import OperationalError, connect | from psycopg import OperationalError, connect | ||||||
| from redis import Redis | from redis import Redis | ||||||
| @ -35,7 +34,7 @@ def check_postgres(): | |||||||
|  |  | ||||||
|  |  | ||||||
| def check_redis(): | def check_redis(): | ||||||
|     url = redis_url(CONFIG.get("redis.db")) |     url = CONFIG.get("cache.url") or redis_url(CONFIG.get("redis.db")) | ||||||
|     while True: |     while True: | ||||||
|         try: |         try: | ||||||
|             redis = Redis.from_url(url) |             redis = Redis.from_url(url) | ||||||
| @ -43,10 +42,7 @@ def check_redis(): | |||||||
|             break |             break | ||||||
|         except RedisError as exc: |         except RedisError as exc: | ||||||
|             sleep(1) |             sleep(1) | ||||||
|             sanitized_url = url.replace(quote_plus(CONFIG.get("redis.password")), "******") |             CONFIG.log("info", f"Redis Connection failed, retrying... ({exc})") | ||||||
|             CONFIG.log( |  | ||||||
|                 "info", f"Redis Connection failed, retrying... ({exc})", redis_url=sanitized_url |  | ||||||
|             ) |  | ||||||
|     CONFIG.log("info", "Redis Connection successful") |     CONFIG.log("info", "Redis Connection successful") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
										1910
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1910
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -1,6 +1,6 @@ | |||||||
| [tool.poetry] | [tool.poetry] | ||||||
| name = "authentik" | name = "authentik" | ||||||
| version = "2024.2.3" | version = "2024.4.4" | ||||||
| description = "" | description = "" | ||||||
| authors = ["authentik Team <hello@goauthentik.io>"] | authors = ["authentik Team <hello@goauthentik.io>"] | ||||||
|  |  | ||||||
| @ -89,6 +89,7 @@ channels = { version = "*", extras = ["daphne"] } | |||||||
| channels-redis = "*" | channels-redis = "*" | ||||||
| codespell = "*" | codespell = "*" | ||||||
| colorama = "*" | colorama = "*" | ||||||
|  | cryptography = "*" | ||||||
| dacite = "*" | dacite = "*" | ||||||
| deepmerge = "*" | deepmerge = "*" | ||||||
| defusedxml = "*" | defusedxml = "*" | ||||||
| @ -101,7 +102,7 @@ django-redis = "*" | |||||||
| django-storages = { extras = ["s3"], version = "*" } | django-storages = { extras = ["s3"], version = "*" } | ||||||
| # See https://github.com/django-tenants/django-tenants/pull/997 | # See https://github.com/django-tenants/django-tenants/pull/997 | ||||||
| django-tenants = { git = "https://github.com/rissson/django-tenants.git", branch="authentik-fixes" } | django-tenants = { git = "https://github.com/rissson/django-tenants.git", branch="authentik-fixes" } | ||||||
| djangorestframework = "*" | djangorestframework = "3.14.0" | ||||||
| djangorestframework-guardian = "*" | djangorestframework-guardian = "*" | ||||||
| docker = "*" | docker = "*" | ||||||
| drf-spectacular = "*" | drf-spectacular = "*" | ||||||
| @ -115,17 +116,11 @@ gunicorn = "*" | |||||||
| jsonpatch = "*" | jsonpatch = "*" | ||||||
| kubernetes = "*" | kubernetes = "*" | ||||||
| ldap3 = "*" | ldap3 = "*" | ||||||
| lxml = [ | lxml = "*" | ||||||
|     # 5.0.0 works with libxml2 2.11.x, which is standard on brew |  | ||||||
|     { version = "5.0.0", platform = "darwin" }, |  | ||||||
|     # 4.9.x works with previous libxml2 versions, which is what we get on linux |  | ||||||
|     { version = "4.9.4", platform = "linux" }, |  | ||||||
| ] |  | ||||||
| opencontainers = { extras = ["reggie"], version = "*" } | opencontainers = { extras = ["reggie"], version = "*" } | ||||||
| packaging = "*" | packaging = "*" | ||||||
| paramiko = "*" | paramiko = "*" | ||||||
| psycopg = { extras = ["c"], version = "*" } | psycopg = { extras = ["c"], version = "*" } | ||||||
| pycryptodome = "*" |  | ||||||
| pydantic = "*" | pydantic = "*" | ||||||
| pydantic-scim = "*" | pydantic-scim = "*" | ||||||
| pyjwt = "*" | pyjwt = "*" | ||||||
|  | |||||||
							
								
								
									
										23
									
								
								schema.yml
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								schema.yml
									
									
									
									
									
								
							| @ -1,7 +1,7 @@ | |||||||
| openapi: 3.0.3 | openapi: 3.0.3 | ||||||
| info: | info: | ||||||
|   title: authentik |   title: authentik | ||||||
|   version: 2024.2.3 |   version: 2024.4.4 | ||||||
|   description: Making authentication simple. |   description: Making authentication simple. | ||||||
|   contact: |   contact: | ||||||
|     email: hello@goauthentik.io |     email: hello@goauthentik.io | ||||||
| @ -17051,6 +17051,10 @@ paths: | |||||||
|           enum: |           enum: | ||||||
|           - http://www.w3.org/2000/09/xmldsig#dsa-sha1 |           - http://www.w3.org/2000/09/xmldsig#dsa-sha1 | ||||||
|           - http://www.w3.org/2000/09/xmldsig#rsa-sha1 |           - http://www.w3.org/2000/09/xmldsig#rsa-sha1 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 | ||||||
| @ -20910,6 +20914,10 @@ paths: | |||||||
|           enum: |           enum: | ||||||
|           - http://www.w3.org/2000/09/xmldsig#dsa-sha1 |           - http://www.w3.org/2000/09/xmldsig#dsa-sha1 | ||||||
|           - http://www.w3.org/2000/09/xmldsig#rsa-sha1 |           - http://www.w3.org/2000/09/xmldsig#rsa-sha1 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384 | ||||||
|  |           - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 | ||||||
|           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 |           - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 | ||||||
| @ -30450,6 +30458,11 @@ components: | |||||||
|       - pending_user |       - pending_user | ||||||
|       - pending_user_avatar |       - pending_user_avatar | ||||||
|       - type |       - type | ||||||
|  |     AlgEnum: | ||||||
|  |       enum: | ||||||
|  |       - rsa | ||||||
|  |       - ecdsa | ||||||
|  |       type: string | ||||||
|     App: |     App: | ||||||
|       type: object |       type: object | ||||||
|       description: Serialize Application info |       description: Serialize Application info | ||||||
| @ -32107,6 +32120,10 @@ components: | |||||||
|           type: string |           type: string | ||||||
|         validity_days: |         validity_days: | ||||||
|           type: integer |           type: integer | ||||||
|  |         alg: | ||||||
|  |           allOf: | ||||||
|  |           - $ref: '#/components/schemas/AlgEnum' | ||||||
|  |           default: rsa | ||||||
|       required: |       required: | ||||||
|       - common_name |       - common_name | ||||||
|       - validity_days |       - validity_days | ||||||
| @ -43658,6 +43675,10 @@ components: | |||||||
|       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 |       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha256 | ||||||
|       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 |       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha384 | ||||||
|       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 |       - http://www.w3.org/2001/04/xmldsig-more#rsa-sha512 | ||||||
|  |       - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha1 | ||||||
|  |       - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256 | ||||||
|  |       - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha384 | ||||||
|  |       - http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha512 | ||||||
|       - http://www.w3.org/2000/09/xmldsig#dsa-sha1 |       - http://www.w3.org/2000/09/xmldsig#dsa-sha1 | ||||||
|       type: string |       type: string | ||||||
|     Source: |     Source: | ||||||
|  | |||||||
							
								
								
									
										1934
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										1934
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -38,7 +38,7 @@ | |||||||
|         "@codemirror/theme-one-dark": "^6.1.2", |         "@codemirror/theme-one-dark": "^6.1.2", | ||||||
|         "@formatjs/intl-listformat": "^7.5.5", |         "@formatjs/intl-listformat": "^7.5.5", | ||||||
|         "@fortawesome/fontawesome-free": "^6.5.2", |         "@fortawesome/fontawesome-free": "^6.5.2", | ||||||
|         "@goauthentik/api": "^2024.2.3-1713441634", |         "@goauthentik/api": "^2024.4.1-1714655911", | ||||||
|         "@lit-labs/task": "^3.1.0", |         "@lit-labs/task": "^3.1.0", | ||||||
|         "@lit/context": "^1.1.1", |         "@lit/context": "^1.1.1", | ||||||
|         "@lit/localize": "^0.12.1", |         "@lit/localize": "^0.12.1", | ||||||
|  | |||||||
| @ -29,5 +29,9 @@ export const signatureAlgorithmOptions = toOptions([ | |||||||
|     ["RSA-SHA256", SignatureAlgorithmEnum._200104XmldsigMorersaSha256, true], |     ["RSA-SHA256", SignatureAlgorithmEnum._200104XmldsigMorersaSha256, true], | ||||||
|     ["RSA-SHA384", SignatureAlgorithmEnum._200104XmldsigMorersaSha384], |     ["RSA-SHA384", SignatureAlgorithmEnum._200104XmldsigMorersaSha384], | ||||||
|     ["RSA-SHA512", SignatureAlgorithmEnum._200104XmldsigMorersaSha512], |     ["RSA-SHA512", SignatureAlgorithmEnum._200104XmldsigMorersaSha512], | ||||||
|  |     ["ECDSA-SHA1", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha1], | ||||||
|  |     ["ECDSA-SHA256", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha256], | ||||||
|  |     ["ECDSA-SHA384", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha384], | ||||||
|  |     ["ECDSA-SHA512", SignatureAlgorithmEnum._200104XmldsigMoreecdsaSha512], | ||||||
|     ["DSA-SHA1", SignatureAlgorithmEnum._200009XmldsigdsaSha1], |     ["DSA-SHA1", SignatureAlgorithmEnum._200009XmldsigdsaSha1], | ||||||
| ]); | ]); | ||||||
|  | |||||||
| @ -6,7 +6,12 @@ import { msg } from "@lit/localize"; | |||||||
| import { TemplateResult, html } from "lit"; | import { TemplateResult, html } from "lit"; | ||||||
| import { customElement } from "lit/decorators.js"; | import { customElement } from "lit/decorators.js"; | ||||||
|  |  | ||||||
| import { CertificateGenerationRequest, CertificateKeyPair, CryptoApi } from "@goauthentik/api"; | import { | ||||||
|  |     AlgEnum, | ||||||
|  |     CertificateGenerationRequest, | ||||||
|  |     CertificateKeyPair, | ||||||
|  |     CryptoApi, | ||||||
|  | } from "@goauthentik/api"; | ||||||
|  |  | ||||||
| @customElement("ak-crypto-certificate-generate-form") | @customElement("ak-crypto-certificate-generate-form") | ||||||
| export class CertificateKeyPairForm extends Form<CertificateGenerationRequest> { | export class CertificateKeyPairForm extends Form<CertificateGenerationRequest> { | ||||||
| @ -40,6 +45,29 @@ export class CertificateKeyPairForm extends Form<CertificateGenerationRequest> { | |||||||
|                 ?required=${true} |                 ?required=${true} | ||||||
|             > |             > | ||||||
|                 <input class="pf-c-form-control" type="number" value="365" /> |                 <input class="pf-c-form-control" type="number" value="365" /> | ||||||
|  |             </ak-form-element-horizontal> | ||||||
|  |             <ak-form-element-horizontal | ||||||
|  |                 label=${msg("Private key Algorithm")} | ||||||
|  |                 ?required=${true} | ||||||
|  |                 name="alg" | ||||||
|  |             > | ||||||
|  |                 <ak-radio | ||||||
|  |                     .options=${[ | ||||||
|  |                         { | ||||||
|  |                             label: msg("RSA"), | ||||||
|  |                             value: AlgEnum.Rsa, | ||||||
|  |                             default: true, | ||||||
|  |                         }, | ||||||
|  |                         { | ||||||
|  |                             label: msg("ECDSA"), | ||||||
|  |                             value: AlgEnum.Ecdsa, | ||||||
|  |                         }, | ||||||
|  |                     ]} | ||||||
|  |                 > | ||||||
|  |                 </ak-radio> | ||||||
|  |                 <p class="pf-c-form__helper-text"> | ||||||
|  |                     ${msg("Algorithm used to generate the private key.")} | ||||||
|  |                 </p> | ||||||
|             </ak-form-element-horizontal> `; |             </ak-form-element-horizontal> `; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -97,7 +97,7 @@ export class EventListPage extends TablePage<Event> { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     renderExpanded(item: Event): TemplateResult { |     renderExpanded(item: Event): TemplateResult { | ||||||
|         return html` <td role="cell" colspan="3"> |         return html` <td role="cell" colspan="5"> | ||||||
|                 <div class="pf-c-table__expandable-row-content"> |                 <div class="pf-c-table__expandable-row-content"> | ||||||
|                     <ak-event-info .event=${item as EventWithContext}></ak-event-info> |                     <ak-event-info .event=${item as EventWithContext}></ak-event-info> | ||||||
|                 </div> |                 </div> | ||||||
|  | |||||||
| @ -214,21 +214,16 @@ export class IdentificationStageForm extends BaseStageForm<IdentificationStage> | |||||||
|                         name="sources" |                         name="sources" | ||||||
|                     > |                     > | ||||||
|                         <select class="pf-c-form-control" multiple> |                         <select class="pf-c-form-control" multiple> | ||||||
|                             ${this.sources?.results.map((source) => { |                             ${this.sources?.results | ||||||
|                                 let selected = Array.from(this.instance?.sources || []).some( |                                 .filter((source) => { | ||||||
|  |                                     return source.component !== ""; | ||||||
|  |                                 }) | ||||||
|  |                                 .map((source) => { | ||||||
|  |                                     const selected = Array.from(this.instance?.sources || []).some( | ||||||
|                                         (su) => { |                                         (su) => { | ||||||
|                                             return su == source.pk; |                                             return su == source.pk; | ||||||
|                                         }, |                                         }, | ||||||
|                                     ); |                                     ); | ||||||
|                                 // Creating a new instance, auto-select built-in source |  | ||||||
|                                 // Only when no other sources exist |  | ||||||
|                                 if ( |  | ||||||
|                                     !this.instance && |  | ||||||
|                                     source.component === "" && |  | ||||||
|                                     (this.sources?.results || []).length < 2 |  | ||||||
|                                 ) { |  | ||||||
|                                     selected = true; |  | ||||||
|                                 } |  | ||||||
|                                     return html`<option |                                     return html`<option | ||||||
|                                         value=${ifDefined(source.pk)} |                                         value=${ifDefined(source.pk)} | ||||||
|                                         ?selected=${selected} |                                         ?selected=${selected} | ||||||
|  | |||||||
| @ -128,6 +128,14 @@ export class UserForm extends ModelForm<User, number> { | |||||||
|                                 "Service accounts should be used for machine-to-machine authentication or other automations.", |                                 "Service accounts should be used for machine-to-machine authentication or other automations.", | ||||||
|                             )}`, |                             )}`, | ||||||
|                         }, |                         }, | ||||||
|  |                         { | ||||||
|  |                             label: "Internal Service account", | ||||||
|  |                             value: UserTypeEnum.InternalServiceAccount, | ||||||
|  |                             disabled: true, | ||||||
|  |                             description: html`${msg( | ||||||
|  |                                 "Internal Service accounts are created and managed by authentik and cannot be created manually.", | ||||||
|  |                             )}`, | ||||||
|  |                         }, | ||||||
|                     ]} |                     ]} | ||||||
|                     .value=${this.instance?.type} |                     .value=${this.instance?.type} | ||||||
|                 > |                 > | ||||||
|  | |||||||
| @ -3,7 +3,7 @@ export const SUCCESS_CLASS = "pf-m-success"; | |||||||
| export const ERROR_CLASS = "pf-m-danger"; | export const ERROR_CLASS = "pf-m-danger"; | ||||||
| export const PROGRESS_CLASS = "pf-m-in-progress"; | export const PROGRESS_CLASS = "pf-m-in-progress"; | ||||||
| export const CURRENT_CLASS = "pf-m-current"; | export const CURRENT_CLASS = "pf-m-current"; | ||||||
| export const VERSION = "2024.2.3"; | export const VERSION = "2024.4.4"; | ||||||
| export const TITLE_DEFAULT = "authentik"; | export const TITLE_DEFAULT = "authentik"; | ||||||
| export const ROUTE_SEPARATOR = ";"; | export const ROUTE_SEPARATOR = ";"; | ||||||
|  |  | ||||||
|  | |||||||
| @ -187,6 +187,9 @@ input[type="date"]::-webkit-calendar-picker-indicator { | |||||||
| .pf-c-select__menu-item.pf-m-focus { | .pf-c-select__menu-item.pf-m-focus { | ||||||
|     --pf-c-select__menu-item--focus--BackgroundColor: var(--ak-dark-background-light-ish); |     --pf-c-select__menu-item--focus--BackgroundColor: var(--ak-dark-background-light-ish); | ||||||
| } | } | ||||||
|  | .pf-c-button:disabled { | ||||||
|  |     color: var(--ak-dark-background-lighter); | ||||||
|  | } | ||||||
| .pf-c-button.pf-m-plain:hover { | .pf-c-button.pf-m-plain:hover { | ||||||
|     color: var(--ak-dark-foreground); |     color: var(--ak-dark-foreground); | ||||||
| } | } | ||||||
|  | |||||||
| @ -18,7 +18,7 @@ export function me(): Promise<SessionUser> { | |||||||
|                 if (!user.user.settings || !("locale" in user.user.settings)) { |                 if (!user.user.settings || !("locale" in user.user.settings)) { | ||||||
|                     return user; |                     return user; | ||||||
|                 } |                 } | ||||||
|                 const locale = user.user.settings.locale; |                 const locale: string | undefined = user.user.settings.locale; | ||||||
|                 if (locale && locale !== "") { |                 if (locale && locale !== "") { | ||||||
|                     console.debug( |                     console.debug( | ||||||
|                         `authentik/locale: Activating user's configured locale '${locale}'`, |                         `authentik/locale: Activating user's configured locale '${locale}'`, | ||||||
|  | |||||||
| @ -111,6 +111,21 @@ export function dateTimeLocal(date: Date): string { | |||||||
|     return `${parts[0]}:${parts[1]}`; |     return `${parts[0]}:${parts[1]}`; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | export function dateToUTC(date: Date): Date { | ||||||
|  |     // Sigh...so our API is UTC/can take TZ info in the ISO format as it should. | ||||||
|  |     // datetime-local fields (which is almost the only date-time input we use) | ||||||
|  |     // can return its value as a UTC timestamp...however the generated API client | ||||||
|  |     // _requires_ a Date object, only to then convert it to an ISO string anyways | ||||||
|  |     // JS Dates don't include timezone info in the ISO string, so that just sends | ||||||
|  |     // the local time as UTC...which is wrong | ||||||
|  |     // Instead we have to do this, convert the given date to a UTC timestamp, | ||||||
|  |     // then subtract the timezone offset to create an "invalid" date (correct time&date) | ||||||
|  |     // but it still "thinks" it's in local TZ | ||||||
|  |     const timestamp = date.getTime(); | ||||||
|  |     const offset = -1 * (new Date().getTimezoneOffset() * 60000); | ||||||
|  |     return new Date(timestamp - offset); | ||||||
|  | } | ||||||
|  |  | ||||||
| // Lit is extremely well-typed with regard to CSS, and Storybook's `build` does not currently have a | // Lit is extremely well-typed with regard to CSS, and Storybook's `build` does not currently have a | ||||||
| // coherent way of importing CSS-as-text into CSSStyleSheet. It works well when Storybook is running | // coherent way of importing CSS-as-text into CSSStyleSheet. It works well when Storybook is running | ||||||
| // in `dev,` but in `build` it fails. Storied components will have to map their textual CSS imports | // in `dev,` but in `build` it fails. Storied components will have to map their textual CSS imports | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ import PFDescriptionList from "@patternfly/patternfly/components/DescriptionList | |||||||
| import PFList from "@patternfly/patternfly/components/List/list.css"; | import PFList from "@patternfly/patternfly/components/List/list.css"; | ||||||
| import PFTable from "@patternfly/patternfly/components/Table/table.css"; | import PFTable from "@patternfly/patternfly/components/Table/table.css"; | ||||||
| import PFFlex from "@patternfly/patternfly/layouts/Flex/flex.css"; | import PFFlex from "@patternfly/patternfly/layouts/Flex/flex.css"; | ||||||
|  | import PFSplit from "@patternfly/patternfly/layouts/Split/split.css"; | ||||||
| import PFBase from "@patternfly/patternfly/patternfly-base.css"; | import PFBase from "@patternfly/patternfly/patternfly-base.css"; | ||||||
|  |  | ||||||
| import { EventActions, FlowsApi } from "@goauthentik/api"; | import { EventActions, FlowsApi } from "@goauthentik/api"; | ||||||
| @ -81,6 +82,7 @@ export class EventInfo extends AKElement { | |||||||
|             PFCard, |             PFCard, | ||||||
|             PFTable, |             PFTable, | ||||||
|             PFList, |             PFList, | ||||||
|  |             PFSplit, | ||||||
|             PFDescriptionList, |             PFDescriptionList, | ||||||
|             css` |             css` | ||||||
|                 code { |                 code { | ||||||
| @ -246,11 +248,17 @@ export class EventInfo extends AKElement { | |||||||
|  |  | ||||||
|     renderModelChanged() { |     renderModelChanged() { | ||||||
|         const diff = this.event.context.diff as unknown as { |         const diff = this.event.context.diff as unknown as { | ||||||
|             [key: string]: { new_value: unknown; previous_value: unknown }; |             [key: string]: { | ||||||
|  |                 new_value: unknown; | ||||||
|  |                 previous_value: unknown; | ||||||
|  |                 add?: unknown[]; | ||||||
|  |                 remove?: unknown[]; | ||||||
|  |                 clear?: boolean; | ||||||
|  |             }; | ||||||
|         }; |         }; | ||||||
|         let diffBody = html``; |         let diffBody = html``; | ||||||
|         if (diff) { |         if (diff) { | ||||||
|             diffBody = html`<div class="pf-l-flex__item"> |             diffBody = html`<div class="pf-l-split__item pf-m-fill"> | ||||||
|                     <div class="pf-c-card__title">${msg("Changes made:")}</div> |                     <div class="pf-c-card__title">${msg("Changes made:")}</div> | ||||||
|                     <table class="pf-c-table pf-m-compact pf-m-grid-md" role="grid"> |                     <table class="pf-c-table pf-m-compact pf-m-grid-md" role="grid"> | ||||||
|                         <thead> |                         <thead> | ||||||
| @ -262,16 +270,36 @@ export class EventInfo extends AKElement { | |||||||
|                         </thead> |                         </thead> | ||||||
|                         <tbody role="rowgroup"> |                         <tbody role="rowgroup"> | ||||||
|                             ${Object.keys(diff).map((key) => { |                             ${Object.keys(diff).map((key) => { | ||||||
|  |                                 const value = diff[key]; | ||||||
|  |                                 const previousCol = value.previous_value | ||||||
|  |                                     ? JSON.stringify(value.previous_value, null, 4) | ||||||
|  |                                     : msg("-"); | ||||||
|  |                                 let newCol = html``; | ||||||
|  |                                 if (value.add || value.remove) { | ||||||
|  |                                     newCol = html`<ul class="pf-c-list"> | ||||||
|  |                                         ${(value.add || value.remove)?.map((item) => { | ||||||
|  |                                             let itemLabel = ""; | ||||||
|  |                                             if (value.add) { | ||||||
|  |                                                 itemLabel = msg(str`Added ID ${item}`); | ||||||
|  |                                             } else if (value.remove) { | ||||||
|  |                                                 itemLabel = msg(str`Removed ID ${item}`); | ||||||
|  |                                             } | ||||||
|  |                                             return html`<li>${itemLabel}</li>`; | ||||||
|  |                                         })} | ||||||
|  |                                     </ul>`; | ||||||
|  |                                 } else if (value.clear) { | ||||||
|  |                                     newCol = html`${msg("Cleared")}`; | ||||||
|  |                                 } else { | ||||||
|  |                                     newCol = html`<pre> | ||||||
|  | ${JSON.stringify(value.new_value, null, 4)}</pre | ||||||
|  |                                     >`; | ||||||
|  |                                 } | ||||||
|                                 return html` <tr role="row"> |                                 return html` <tr role="row"> | ||||||
|                                     <td role="cell"><pre>${key}</pre></td> |                                     <td role="cell"><pre>${key}</pre></td> | ||||||
|                                     <td role="cell"> |                                     <td role="cell"> | ||||||
|                                         <pre> |                                         <pre>${previousCol}</pre> | ||||||
| ${JSON.stringify(diff[key].previous_value, null, 4)}</pre |  | ||||||
|                                         > |  | ||||||
|                                     </td> |  | ||||||
|                                     <td role="cell"> |  | ||||||
|                                         <pre>${JSON.stringify(diff[key].new_value, null, 4)}</pre> |  | ||||||
|                                     </td> |                                     </td> | ||||||
|  |                                     <td role="cell">${newCol}</td> | ||||||
|                                 </tr>`; |                                 </tr>`; | ||||||
|                             })} |                             })} | ||||||
|                         </tbody> |                         </tbody> | ||||||
| @ -280,8 +308,8 @@ ${JSON.stringify(diff[key].previous_value, null, 4)}</pre | |||||||
|                 </div>`; |                 </div>`; | ||||||
|         } |         } | ||||||
|         return html` |         return html` | ||||||
|             <div class="pf-l-flex"> |             <div class="pf-l-split"> | ||||||
|                 <div class="pf-l-flex__item"> |                 <div class="pf-l-split__item pf-m-fill"> | ||||||
|                     <div class="pf-c-card__title">${msg("Affected model:")}</div> |                     <div class="pf-c-card__title">${msg("Affected model:")}</div> | ||||||
|                     <div class="pf-c-card__body"> |                     <div class="pf-c-card__body"> | ||||||
|                         ${this.getModelInfo(this.event.context?.model as EventModel)} |                         ${this.getModelInfo(this.event.context?.model as EventModel)} | ||||||
|  | |||||||
| @ -87,7 +87,7 @@ export class Markdown extends AKElement { | |||||||
|             const parsedContent = matter(this.md); |             const parsedContent = matter(this.md); | ||||||
|             const parsedHTML = this.converter.makeHtml(parsedContent.content); |             const parsedHTML = this.converter.makeHtml(parsedContent.content); | ||||||
|             const replacers = [...this.defaultReplacers, ...this.replacers]; |             const replacers = [...this.defaultReplacers, ...this.replacers]; | ||||||
|             this.docTitle = parsedContent.data["title"] ?? ""; |             this.docTitle = parsedContent?.data?.title ?? ""; | ||||||
|             this.docHtml = replacers.reduce( |             this.docHtml = replacers.reduce( | ||||||
|                 (html, replacer) => replacer(html, { path: this.meta }), |                 (html, replacer) => replacer(html, { path: this.meta }), | ||||||
|                 parsedHTML, |                 parsedHTML, | ||||||
|  | |||||||
| @ -13,7 +13,7 @@ import { WithBrandConfig } from "@goauthentik/elements/Interface/brandProvider"; | |||||||
| import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; | import "@patternfly/elements/pf-tooltip/pf-tooltip.js"; | ||||||
|  |  | ||||||
| import { msg } from "@lit/localize"; | import { msg } from "@lit/localize"; | ||||||
| import { CSSResult, PropertyValues, TemplateResult, css, html } from "lit"; | import { CSSResult, TemplateResult, css, html } from "lit"; | ||||||
| import { customElement, property } from "lit/decorators.js"; | import { customElement, property } from "lit/decorators.js"; | ||||||
|  |  | ||||||
| import PFButton from "@patternfly/patternfly/components/Button/button.css"; | import PFButton from "@patternfly/patternfly/components/Button/button.css"; | ||||||
| @ -107,22 +107,24 @@ export class PageHeader extends WithBrandConfig(AKElement) { | |||||||
|         }); |         }); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     setTitle(value: string) { |     setTitle(header?: string) { | ||||||
|         const currentIf = currentInterface(); |         const currentIf = currentInterface(); | ||||||
|         const title = this.brand?.brandingTitle || TITLE_DEFAULT; |         let title = this.brand?.brandingTitle || TITLE_DEFAULT; | ||||||
|         document.title = |         if (currentIf === "admin") { | ||||||
|             currentIf === "admin" |             title = `${msg("Admin")} - ${title}`; | ||||||
|                 ? `${msg("Admin")} - ${title}` |         } | ||||||
|                 : value !== "" |         // Prepend the header to the title | ||||||
|                   ? `${value} - ${title}` |         if (header !== undefined && header !== "") { | ||||||
|                   : title; |             title = `${header} - ${title}`; | ||||||
|  |         } | ||||||
|  |         document.title = title; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     willUpdate(changedProperties: PropertyValues<this>) { |     willUpdate() { | ||||||
|         if (changedProperties.has("header") && this.header) { |         // Always update title, even if there's no header value set, | ||||||
|  |         // as in that case we still need to return to the generic title | ||||||
|         this.setTitle(this.header); |         this.setTitle(this.header); | ||||||
|     } |     } | ||||||
|     } |  | ||||||
|  |  | ||||||
|     renderIcon(): TemplateResult { |     renderIcon(): TemplateResult { | ||||||
|         if (this.icon) { |         if (this.icon) { | ||||||
|  | |||||||
| @ -2,8 +2,9 @@ import { AKElement } from "@goauthentik/elements/Base"; | |||||||
| import { CustomEmitterElement } from "@goauthentik/elements/utils/eventEmitter"; | import { CustomEmitterElement } from "@goauthentik/elements/utils/eventEmitter"; | ||||||
|  |  | ||||||
| import { msg } from "@lit/localize"; | import { msg } from "@lit/localize"; | ||||||
|  | import { PropertyValues } from "@lit/reactive-element/reactive-element"; | ||||||
| import { TemplateResult, css, html } from "lit"; | import { TemplateResult, css, html } from "lit"; | ||||||
| import { customElement, property, queryAll } from "lit/decorators.js"; | import { customElement, property, queryAll, state } from "lit/decorators.js"; | ||||||
| import { map } from "lit/directives/map.js"; | import { map } from "lit/directives/map.js"; | ||||||
|  |  | ||||||
| import PFCheck from "@patternfly/patternfly/components/Check/check.css"; | import PFCheck from "@patternfly/patternfly/components/Check/check.css"; | ||||||
| @ -112,10 +113,14 @@ export class CheckboxGroup extends AkElementWithCustomEvents { | |||||||
|     @queryAll('input[type="checkbox"]') |     @queryAll('input[type="checkbox"]') | ||||||
|     checkboxes!: NodeListOf<HTMLInputElement>; |     checkboxes!: NodeListOf<HTMLInputElement>; | ||||||
|  |  | ||||||
|     internals?: ElementInternals; |     @state() | ||||||
|  |     values: string[] = []; | ||||||
|  |  | ||||||
|     get json() { |     internals?: ElementInternals; | ||||||
|         return this.value; |     doneFirstUpdate = false; | ||||||
|  |  | ||||||
|  |     json() { | ||||||
|  |         return this.values; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     private get formValue() { |     private get formValue() { | ||||||
| @ -124,7 +129,7 @@ export class CheckboxGroup extends AkElementWithCustomEvents { | |||||||
|         } |         } | ||||||
|         const name = this.name; |         const name = this.name; | ||||||
|         const entries = new FormData(); |         const entries = new FormData(); | ||||||
|         this.value.forEach((v) => entries.append(name, v)); |         this.values.forEach((v) => entries.append(name, v)); | ||||||
|         return entries; |         return entries; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @ -136,14 +141,14 @@ export class CheckboxGroup extends AkElementWithCustomEvents { | |||||||
|  |  | ||||||
|     onClick(ev: Event) { |     onClick(ev: Event) { | ||||||
|         ev.stopPropagation(); |         ev.stopPropagation(); | ||||||
|         this.value = Array.from(this.checkboxes) |         this.values = Array.from(this.checkboxes) | ||||||
|             .filter((checkbox) => checkbox.checked) |             .filter((checkbox) => checkbox.checked) | ||||||
|             .map((checkbox) => checkbox.name); |             .map((checkbox) => checkbox.name); | ||||||
|         this.dispatchCustomEvent("change", this.value); |         this.dispatchCustomEvent("change", this.values); | ||||||
|         this.dispatchCustomEvent("input", this.value); |         this.dispatchCustomEvent("input", this.values); | ||||||
|         if (this.internals) { |         if (this.internals) { | ||||||
|             this.internals.setValidity({}); |             this.internals.setValidity({}); | ||||||
|             if (this.required && this.value.length === 0) { |             if (this.required && this.values.length === 0) { | ||||||
|                 this.internals.setValidity( |                 this.internals.setValidity( | ||||||
|                     { |                     { | ||||||
|                         valueMissing: true, |                         valueMissing: true, | ||||||
| @ -154,6 +159,16 @@ export class CheckboxGroup extends AkElementWithCustomEvents { | |||||||
|             } |             } | ||||||
|             this.internals.setFormValue(this.formValue); |             this.internals.setFormValue(this.formValue); | ||||||
|         } |         } | ||||||
|  |         // Doing a write-back so anyone examining the checkbox.value field will get something | ||||||
|  |         // meaningful. Doesn't do anything for anyone, usually, but it's nice to have. | ||||||
|  |         this.value = this.values; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     willUpdate(changed: PropertyValues<this>) { | ||||||
|  |         if (changed.has("value") && !this.doneFirstUpdate) { | ||||||
|  |             this.doneFirstUpdate = true; | ||||||
|  |             this.values = this.value; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     connectedCallback() { |     connectedCallback() { | ||||||
| @ -183,7 +198,7 @@ export class CheckboxGroup extends AkElementWithCustomEvents { | |||||||
|  |  | ||||||
|     render() { |     render() { | ||||||
|         const renderOne = ([name, label]: CheckboxPr) => { |         const renderOne = ([name, label]: CheckboxPr) => { | ||||||
|             const selected = this.value.includes(name); |             const selected = this.values.includes(name); | ||||||
|             const blockFwd = (e: Event) => { |             const blockFwd = (e: Event) => { | ||||||
|                 e.stopImmediatePropagation(); |                 e.stopImmediatePropagation(); | ||||||
|             }; |             }; | ||||||
|  | |||||||
| @ -53,6 +53,9 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) { | |||||||
|  |  | ||||||
|     private isLoading = false; |     private isLoading = false; | ||||||
|  |  | ||||||
|  |     private doneFirstUpdate = false; | ||||||
|  |     private internalSelected: DualSelectPair[] = []; | ||||||
|  |  | ||||||
|     private pagination?: Pagination; |     private pagination?: Pagination; | ||||||
|  |  | ||||||
|     constructor() { |     constructor() { | ||||||
| @ -69,6 +72,11 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     willUpdate(changedProperties: PropertyValues<this>) { |     willUpdate(changedProperties: PropertyValues<this>) { | ||||||
|  |         if (changedProperties.has("selected") && !this.doneFirstUpdate) { | ||||||
|  |             this.doneFirstUpdate = true; | ||||||
|  |             this.internalSelected = this.selected; | ||||||
|  |         } | ||||||
|  |  | ||||||
|         if (changedProperties.has("searchDelay")) { |         if (changedProperties.has("searchDelay")) { | ||||||
|             this.doSearch = debounce( |             this.doSearch = debounce( | ||||||
|                 AkDualSelectProvider.prototype.doSearch.bind(this), |                 AkDualSelectProvider.prototype.doSearch.bind(this), | ||||||
| @ -105,7 +113,8 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) { | |||||||
|         if (!(event instanceof CustomEvent)) { |         if (!(event instanceof CustomEvent)) { | ||||||
|             throw new Error(`Expecting a CustomEvent for change, received ${event} instead`); |             throw new Error(`Expecting a CustomEvent for change, received ${event} instead`); | ||||||
|         } |         } | ||||||
|         this.selected = event.detail.value; |         this.internalSelected = event.detail.value; | ||||||
|  |         this.selected = this.internalSelected; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     onSearch(event: Event) { |     onSearch(event: Event) { | ||||||
| @ -124,12 +133,16 @@ export class AkDualSelectProvider extends CustomListenerElement(AKElement) { | |||||||
|         return this.dualSelector.value!.selected.map(([k, _]) => k); |         return this.dualSelector.value!.selected.map(([k, _]) => k); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     json() { | ||||||
|  |         return this.value; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     render() { |     render() { | ||||||
|         return html`<ak-dual-select |         return html`<ak-dual-select | ||||||
|             ${ref(this.dualSelector)} |             ${ref(this.dualSelector)} | ||||||
|             .options=${this.options} |             .options=${this.options} | ||||||
|             .pages=${this.pagination} |             .pages=${this.pagination} | ||||||
|             .selected=${this.selected} |             .selected=${this.internalSelected} | ||||||
|             available-label=${this.availableLabel} |             available-label=${this.availableLabel} | ||||||
|             selected-label=${this.selectedLabel} |             selected-label=${this.selectedLabel} | ||||||
|         ></ak-dual-select>`; |         ></ak-dual-select>`; | ||||||
|  | |||||||
| @ -2,7 +2,7 @@ import { EVENT_LOCALE_CHANGE, EVENT_LOCALE_REQUEST } from "@goauthentik/common/c | |||||||
| import { customEvent } from "@goauthentik/elements/utils/customEvents"; | import { customEvent } from "@goauthentik/elements/utils/customEvents"; | ||||||
|  |  | ||||||
| import { LitElement, html } from "lit"; | import { LitElement, html } from "lit"; | ||||||
| import { customElement, property, state } from "lit/decorators.js"; | import { customElement, property } from "lit/decorators.js"; | ||||||
|  |  | ||||||
| import { WithBrandConfig } from "../Interface/brandProvider"; | import { WithBrandConfig } from "../Interface/brandProvider"; | ||||||
| import { initializeLocalization } from "./configureLocale"; | import { initializeLocalization } from "./configureLocale"; | ||||||
| @ -38,9 +38,6 @@ export class LocaleContext extends LocaleContextBase { | |||||||
|  |  | ||||||
|     setLocale: LocaleSetter; |     setLocale: LocaleSetter; | ||||||
|  |  | ||||||
|     @state() |  | ||||||
|     userLocale = ""; |  | ||||||
|  |  | ||||||
|     constructor(code = DEFAULT_LOCALE) { |     constructor(code = DEFAULT_LOCALE) { | ||||||
|         super(); |         super(); | ||||||
|         this.notifyApplication = this.notifyApplication.bind(this); |         this.notifyApplication = this.notifyApplication.bind(this); | ||||||
| @ -59,30 +56,22 @@ export class LocaleContext extends LocaleContextBase { | |||||||
|  |  | ||||||
|     connectedCallback() { |     connectedCallback() { | ||||||
|         super.connectedCallback(); |         super.connectedCallback(); | ||||||
|         // Commenting out until we can come up with a better way of separating the |  | ||||||
|         // "request user identity" with the session expiration heartbeat. |  | ||||||
|         /* |  | ||||||
|             new CoreApi(DEFAULT_CONFIG) |  | ||||||
|                 .coreUsersMeRetrieve() |  | ||||||
|                 .then((user) => (this.userLocale = user?.user?.settings?.locale ?? "")) |  | ||||||
|                 .catch(() => {}); |  | ||||||
|         */ |  | ||||||
|         this.updateLocale(); |         this.updateLocale(); | ||||||
|         window.addEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler); |         window.addEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler as EventListener); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     disconnectedCallback() { |     disconnectedCallback() { | ||||||
|         window.removeEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler); |         window.removeEventListener(EVENT_LOCALE_REQUEST, this.updateLocaleHandler as EventListener); | ||||||
|         super.disconnectedCallback(); |         super.disconnectedCallback(); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     updateLocaleHandler(_ev: Event) { |     updateLocaleHandler(ev: CustomEvent<{ locale: string }>) { | ||||||
|         console.debug("authentik/locale: Locale update request received."); |         console.debug("authentik/locale: Locale update request received."); | ||||||
|         this.updateLocale(); |         this.updateLocale(ev.detail.locale); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     updateLocale() { |     updateLocale(requestedLocale: string | undefined = undefined) { | ||||||
|         const localeRequest = autoDetectLanguage(this.userLocale, this.brand?.defaultLocale); |         const localeRequest = autoDetectLanguage(requestedLocale, this.brand?.defaultLocale); | ||||||
|         const locale = getBestMatchLocale(localeRequest); |         const locale = getBestMatchLocale(localeRequest); | ||||||
|         if (!locale) { |         if (!locale) { | ||||||
|             console.warn(`authentik/locale: failed to find locale for code ${localeRequest}`); |             console.warn(`authentik/locale: failed to find locale for code ${localeRequest}`); | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	